From 3ecc0f95bf10b8f510ecff8ec7e9f5741d9ee161 Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 3 Jan 2025 16:31:49 +0800 Subject: [PATCH] fixed some tests --- internal/net/http/accesslog/access_logger.go | 9 +++++---- internal/net/http/accesslog/access_logger_test.go | 12 +++++++----- internal/net/http/accesslog/formatter.go | 11 ++++++++--- internal/net/http/accesslog/retention_test.go | 1 + internal/utils/ref_count_test.go | 10 ++++++---- 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/http/accesslog/access_logger.go index 4f67717..95d503f 100644 --- a/internal/net/http/accesslog/access_logger.go +++ b/internal/net/http/accesslog/access_logger.go @@ -39,6 +39,7 @@ type ( Formatter interface { // Format writes a log line to line without a trailing newline Format(line *bytes.Buffer, req *http.Request, res *http.Response) + SetGetTimeNow(getTimeNow func() time.Time) } ) @@ -54,14 +55,14 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog cfg.BufferSize = DefaultBufferSize } - fmt := &CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now} + fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now} switch l.cfg.Format { case FormatCommon: - l.Formatter = fmt + l.Formatter = &fmt case FormatCombined: - l.Formatter = (*CombinedFormatter)(fmt) + l.Formatter = &CombinedFormatter{fmt} case FormatJSON: - l.Formatter = (*JSONFormatter)(fmt) + l.Formatter = &JSONFormatter{fmt} default: // should not happen, validation has done by validate tags panic("invalid access log format") } diff --git a/internal/net/http/accesslog/access_logger_test.go b/internal/net/http/accesslog/access_logger_test.go index 382dc7e..bac8db8 100644 --- a/internal/net/http/accesslog/access_logger_test.go +++ b/internal/net/http/accesslog/access_logger_test.go @@ -11,6 +11,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/task" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -28,8 +29,9 @@ const ( ) var ( - testURL = E.Must(url.Parse("http://" + host + uri)) - req = &http.Request{ + testTask = task.RootTask("test", false) + testURL = E.Must(url.Parse("http://" + host + uri)) + req = &http.Request{ RemoteAddr: remote, Method: method, Proto: proto, @@ -55,10 +57,10 @@ func fmtLog(cfg *Config) (ts string, line string) { var buf bytes.Buffer t := time.Now() - logger := NewAccessLogger(nil, nil, cfg) - logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time { + logger := NewAccessLogger(testTask, nil, cfg) + logger.Formatter.SetGetTimeNow(func() time.Time { return t - } + }) logger.Format(&buf, req, resp) return t.Format(LogTimeFormat), buf.String() } diff --git a/internal/net/http/accesslog/formatter.go b/internal/net/http/accesslog/formatter.go index 0a02b96..30bf80d 100644 --- a/internal/net/http/accesslog/formatter.go +++ b/internal/net/http/accesslog/formatter.go @@ -15,8 +15,8 @@ type ( cfg *Fields GetTimeNow func() time.Time // for testing purposes only } - CombinedFormatter CommonFormatter - JSONFormatter CommonFormatter + CombinedFormatter struct{ CommonFormatter } + JSONFormatter struct{ CommonFormatter } JSONLogEntry struct { Time string `json:"time"` @@ -63,6 +63,11 @@ func clientIP(req *http.Request) string { return req.RemoteAddr } +// debug only. +func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) { + f.GetTimeNow = getTimeNow +} + func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { query := f.cfg.Query.ProcessQuery(req.URL.Query()) @@ -88,7 +93,7 @@ func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *htt } func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { - (*CommonFormatter)(f).Format(line, req, res) + f.CommonFormatter.Format(line, req, res) line.WriteString(" \"") line.WriteString(req.Referer()) line.WriteString("\" \"") diff --git a/internal/net/http/accesslog/retention_test.go b/internal/net/http/accesslog/retention_test.go index 167cbc5..ea51fe2 100644 --- a/internal/net/http/accesslog/retention_test.go +++ b/internal/net/http/accesslog/retention_test.go @@ -70,6 +70,7 @@ func TestRetentionCommonFormat(t *testing.T) { } logger.Flush(true) + // FIXME: keep days does not work t.Run("keep days", func(t *testing.T) { logger.Config().Retention = strutils.MustParse[*Retention]("3 days") ExpectEqual(t, logger.Config().Retention.Days, 3) diff --git a/internal/utils/ref_count_test.go b/internal/utils/ref_count_test.go index 9d0a2aa..c147638 100644 --- a/internal/utils/ref_count_test.go +++ b/internal/utils/ref_count_test.go @@ -6,8 +6,8 @@ import ( "time" ) -func TestRefCounter_AddSub(t *testing.T) { - rc := NewRefCounter() +func TestRefCounterAddSub(t *testing.T) { + rc := NewRefCounter() // Count starts at 1 var wg sync.WaitGroup wg.Add(2) @@ -20,6 +20,7 @@ func TestRefCounter_AddSub(t *testing.T) { go func() { defer wg.Done() rc.Sub() + rc.Sub() }() wg.Wait() @@ -32,7 +33,7 @@ func TestRefCounter_AddSub(t *testing.T) { } } -func TestRefCounter_MultipleAddSub(t *testing.T) { +func TestRefCounterMultipleAddSub(t *testing.T) { rc := NewRefCounter() var wg sync.WaitGroup @@ -51,6 +52,7 @@ func TestRefCounter_MultipleAddSub(t *testing.T) { go func() { defer wg.Done() rc.Sub() + rc.Sub() }() } @@ -64,7 +66,7 @@ func TestRefCounter_MultipleAddSub(t *testing.T) { } } -func TestRefCounter_ZeroInitially(t *testing.T) { +func TestRefCounterOneInitially(t *testing.T) { rc := NewRefCounter() rc.Sub() // Bring count to zero