From 6e30d39b7831dd8f9a936dc2d5c68de7fc59ead9 Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 3 Jan 2025 14:10:09 +0800 Subject: [PATCH] access logger support sharing the same file, tests added for concurrent logging --- internal/net/http/accesslog/access_logger.go | 35 +++---- internal/net/http/accesslog/file_logger.go | 24 ++++- .../net/http/accesslog/file_logger_test.go | 95 +++++++++++++++++++ internal/net/http/accesslog/mock_file.go | 74 +++++++++++++++ internal/net/http/accesslog/retention_test.go | 74 +-------------- 5 files changed, 211 insertions(+), 91 deletions(-) create mode 100644 internal/net/http/accesslog/file_logger_test.go create mode 100644 internal/net/http/accesslog/mock_file.go diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/http/accesslog/access_logger.go index 0656823..4f67717 100644 --- a/internal/net/http/accesslog/access_logger.go +++ b/internal/net/http/accesslog/access_logger.go @@ -18,10 +18,11 @@ type ( cfg *Config io AccessLogIO - buf bytes.Buffer - bufPool sync.Pool + buf bytes.Buffer // buffer for non-flushed log + bufMu sync.Mutex // protect buf + bufPool sync.Pool // buffer pool for formatting a single log line + flushThreshold int - flushMu sync.Mutex Formatter } @@ -61,6 +62,8 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog l.Formatter = (*CombinedFormatter)(fmt) case FormatJSON: l.Formatter = (*JSONFormatter)(fmt) + default: // should not happen, validation has done by validate tags + panic("invalid access log format") } l.flushThreshold = int(cfg.BufferSize * 4 / 5) // 80% @@ -91,11 +94,11 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) { l.Format(line, req, res) line.WriteRune('\n') - l.flushMu.Lock() + l.bufMu.Lock() l.buf.Write(line.Bytes()) line.Reset() l.bufPool.Put(line) - l.flushMu.Unlock() + l.bufMu.Unlock() } func (l *AccessLogger) LogError(req *http.Request, err error) { @@ -116,12 +119,12 @@ func (l *AccessLogger) Rotate() error { return l.cfg.Retention.rotateLogFile(l.io) } -func (l *AccessLogger) Flush() { - if l.buf.Len() >= l.flushThreshold { - l.flushMu.Lock() - l.writeLine(l.buf.Bytes()) +func (l *AccessLogger) Flush(force bool) { + if force || l.buf.Len() >= l.flushThreshold { + l.bufMu.Lock() + l.write(l.buf.Bytes()) l.buf.Reset() - l.flushMu.Unlock() + l.bufMu.Unlock() } } @@ -132,28 +135,28 @@ func (l *AccessLogger) handleErr(err error) { func (l *AccessLogger) start() { defer func() { if l.buf.Len() > 0 { // flush last - l.writeLine(l.buf.Bytes()) + l.write(l.buf.Bytes()) } l.io.Close() l.task.Finish(nil) }() // threshold flush with periodic check - flushTicker := time.NewTicker(3 * time.Second) + flushTicker := time.NewTicker(time.Second) for { select { case <-l.task.Context().Done(): return case <-flushTicker.C: - l.Flush() + l.Flush(false) } } } -func (l *AccessLogger) writeLine(line []byte) { - l.io.Lock() // prevent write on log rotation - _, err := l.io.Write(line) +func (l *AccessLogger) write(data []byte) { + l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers + _, err := l.io.Write(data) l.io.Unlock() if err != nil { l.handleErr(err) diff --git a/internal/net/http/accesslog/file_logger.go b/internal/net/http/accesslog/file_logger.go index aeaa241..7b3aec7 100644 --- a/internal/net/http/accesslog/file_logger.go +++ b/internal/net/http/accesslog/file_logger.go @@ -13,10 +13,26 @@ type File struct { sync.Mutex } +var ( + openedFiles = make(map[string]AccessLogIO) + openedFilesMu sync.Mutex +) + func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { - f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return nil, fmt.Errorf("access log open error: %w", err) + openedFilesMu.Lock() + + var io AccessLogIO + if opened, ok := openedFiles[cfg.Path]; ok { + io = opened + } else { + f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return nil, fmt.Errorf("access log open error: %w", err) + } + io = &File{File: f} + openedFiles[cfg.Path] = io } - return NewAccessLogger(parent, &File{File: f}, cfg), nil + + openedFilesMu.Unlock() + return NewAccessLogger(parent, io, cfg), nil } diff --git a/internal/net/http/accesslog/file_logger_test.go b/internal/net/http/accesslog/file_logger_test.go new file mode 100644 index 0000000..ffa7aab --- /dev/null +++ b/internal/net/http/accesslog/file_logger_test.go @@ -0,0 +1,95 @@ +package accesslog + +import ( + "net/http" + "os" + "sync" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" + + "github.com/yusing/go-proxy/internal/task" +) + +func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { + var wg sync.WaitGroup + + cfg := DefaultConfig() + cfg.Path = "test.log" + parent := task.RootTask("test", false) + + loggerCount := 10 + accessLogIOs := make([]AccessLogIO, loggerCount) + + // make test log file + file, err := os.Create(cfg.Path) + ExpectNoError(t, err) + file.Close() + t.Cleanup(func() { + ExpectNoError(t, os.Remove(cfg.Path)) + }) + + for i := range loggerCount { + wg.Add(1) + go func(index int) { + defer wg.Done() + logger, err := NewFileAccessLogger(parent, cfg) + ExpectNoError(t, err) + accessLogIOs[index] = logger.io + }(i) + } + + wg.Wait() + + firstIO := accessLogIOs[0] + for _, io := range accessLogIOs { + ExpectEqual(t, io, firstIO) + } +} + +func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { + var file MockFile + + cfg := DefaultConfig() + cfg.BufferSize = 1024 + parent := task.RootTask("test", false) + + loggerCount := 5 + logCountPerLogger := 10 + loggers := make([]*AccessLogger, loggerCount) + + for i := range loggerCount { + loggers[i] = NewAccessLogger(parent, &file, cfg) + } + + var wg sync.WaitGroup + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + resp := &http.Response{StatusCode: http.StatusOK} + + for _, logger := range loggers { + wg.Add(1) + go func(l *AccessLogger) { + defer wg.Done() + parallelLog(l, req, resp, logCountPerLogger) + l.Flush(true) + }(logger) + } + + wg.Wait() + + expected := loggerCount * logCountPerLogger + actual := file.Count() + ExpectEqual(t, actual, expected) +} + +func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) { + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + logger.Log(req, resp) + }() + } + wg.Wait() +} diff --git a/internal/net/http/accesslog/mock_file.go b/internal/net/http/accesslog/mock_file.go new file mode 100644 index 0000000..f960429 --- /dev/null +++ b/internal/net/http/accesslog/mock_file.go @@ -0,0 +1,74 @@ +package accesslog + +import ( + "bytes" + "io" + "sync" +) + +type MockFile struct { + data []byte + position int64 + sync.Mutex +} + +func (m *MockFile) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + m.position = offset + case io.SeekCurrent: + m.position += offset + case io.SeekEnd: + m.position = int64(len(m.data)) + offset + } + return m.position, nil +} + +func (m *MockFile) Write(p []byte) (n int, err error) { + m.data = append(m.data, p...) + n = len(p) + m.position += int64(n) + return +} + +func (m *MockFile) Name() string { + return "mock" +} + +func (m *MockFile) Read(p []byte) (n int, err error) { + if m.position >= int64(len(m.data)) { + return 0, io.EOF + } + n = copy(p, m.data[m.position:]) + m.position += int64(n) + return n, nil +} + +func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) { + if off >= int64(len(m.data)) { + return 0, io.EOF + } + n = copy(p, m.data[off:]) + m.position += int64(n) + return n, nil +} + +func (m *MockFile) Close() error { + return nil +} + +func (m *MockFile) Truncate(size int64) error { + m.data = m.data[:size] + m.position = size + return nil +} + +func (m *MockFile) Count() int { + m.Lock() + defer m.Unlock() + return bytes.Count(m.data[:m.position], []byte("\n")) +} + +func (m *MockFile) Len() int64 { + return m.position +} diff --git a/internal/net/http/accesslog/retention_test.go b/internal/net/http/accesslog/retention_test.go index 606c51c..167cbc5 100644 --- a/internal/net/http/accesslog/retention_test.go +++ b/internal/net/http/accesslog/retention_test.go @@ -1,12 +1,11 @@ package accesslog_test import ( - "bytes" - "io" "testing" "time" . "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -36,76 +35,9 @@ func TestParseRetention(t *testing.T) { } } -type mockFile struct { - data []byte - position int64 -} - -func (m *mockFile) Seek(offset int64, whence int) (int64, error) { - switch whence { - case io.SeekStart: - m.position = offset - case io.SeekCurrent: - m.position += offset - case io.SeekEnd: - m.position = int64(len(m.data)) + offset - } - return m.position, nil -} - -func (m *mockFile) Write(p []byte) (n int, err error) { - m.data = append(m.data, p...) - n = len(p) - m.position += int64(n) - return -} - -func (m *mockFile) Name() string { - return "mock" -} - -func (m *mockFile) Read(p []byte) (n int, err error) { - if m.position >= int64(len(m.data)) { - return 0, io.EOF - } - n = copy(p, m.data[m.position:]) - m.position += int64(n) - return n, nil -} - -func (m *mockFile) ReadAt(p []byte, off int64) (n int, err error) { - if off >= int64(len(m.data)) { - return 0, io.EOF - } - n = copy(p, m.data[off:]) - m.position += int64(n) - return n, nil -} - -func (m *mockFile) Close() error { - return nil -} - -func (m *mockFile) Truncate(size int64) error { - m.data = m.data[:size] - m.position = size - return nil -} - -func (m *mockFile) Lock() {} -func (m *mockFile) Unlock() {} - -func (m *mockFile) Count() int { - return bytes.Count(m.data[:m.position], []byte("\n")) -} - -func (m *mockFile) Len() int64 { - return m.position -} - func TestRetentionCommonFormat(t *testing.T) { - file := mockFile{} - logger := NewAccessLogger(nil, &file, &Config{ + var file MockFile + logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{ Format: FormatCommon, BufferSize: 1024, })