From 82e2705f4406cd87115dd1b8f9719242a424c5d8 Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 14 Apr 2025 07:15:15 +0800 Subject: [PATCH] feat: stdout access logger and MultiWriter --- internal/entrypoint/entrypoint.go | 2 +- .../net/gphttp/accesslog/access_logger.go | 40 ++++++++++++++-- .../gphttp/accesslog/access_logger_test.go | 2 +- internal/net/gphttp/accesslog/back_scanner.go | 4 +- internal/net/gphttp/accesslog/config.go | 16 ++++++- internal/net/gphttp/accesslog/file_logger.go | 11 ++--- .../net/gphttp/accesslog/file_logger_test.go | 7 ++- internal/net/gphttp/accesslog/multi_writer.go | 46 +++++++++++++++++++ internal/net/gphttp/accesslog/rotate.go | 23 +++++++--- internal/net/gphttp/accesslog/rotate_test.go | 2 +- .../net/gphttp/accesslog/stdout_logger.go | 18 ++++++++ internal/route/fileserver.go | 2 +- internal/route/reverse_proxy.go | 2 +- 13 files changed, 145 insertions(+), 30 deletions(-) create mode 100644 internal/net/gphttp/accesslog/multi_writer.go create mode 100644 internal/net/gphttp/accesslog/stdout_logger.go diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 1a3a166..65e8d07 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -61,7 +61,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) return } - ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) + ep.accessLogger, err = accesslog.NewAccessLogger(parent, cfg) if err != nil { return } diff --git a/internal/net/gphttp/accesslog/access_logger.go b/internal/net/gphttp/accesslog/access_logger.go index 9a6fc8a..1637377 100644 --- a/internal/net/gphttp/accesslog/access_logger.go +++ b/internal/net/gphttp/accesslog/access_logger.go @@ -25,11 +25,15 @@ type ( } AccessLogIO interface { + io.Writer + sync.Locker + Name() string // file name or path + } + + supportRotate interface { io.ReadWriteCloser io.ReadWriteSeeker io.ReaderAt - sync.Locker - Name() string // file name or path Truncate(size int64) error } @@ -40,7 +44,33 @@ type ( } ) -func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { +func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { + var ios []AccessLogIO + + if cfg.Stdout { + ios = append(ios, stdoutIO) + } + + if cfg.Path != "" { + io, err := newFileIO(cfg.Path) + if err != nil { + return nil, err + } + ios = append(ios, io) + } + + if len(ios) == 0 { + return nil, nil + } + + return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil +} + +func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger { + return NewAccessLoggerWithIO(parent, &MockFile{}, cfg) +} + +func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger { if cfg.BufferSize == 0 { cfg.BufferSize = DefaultBufferSize } @@ -152,7 +182,9 @@ func (l *AccessLogger) Flush() error { func (l *AccessLogger) close() { l.io.Lock() defer l.io.Unlock() - l.io.Close() + if r, ok := l.io.(io.Closer); ok { + r.Close() + } } func (l *AccessLogger) write(data []byte) { diff --git a/internal/net/gphttp/accesslog/access_logger_test.go b/internal/net/gphttp/accesslog/access_logger_test.go index 012d8eb..7e7af45 100644 --- a/internal/net/gphttp/accesslog/access_logger_test.go +++ b/internal/net/gphttp/accesslog/access_logger_test.go @@ -56,7 +56,7 @@ func fmtLog(cfg *Config) (ts string, line string) { var buf bytes.Buffer t := time.Now() - logger := NewAccessLogger(testTask, nil, cfg) + logger := NewMockAccessLogger(testTask, cfg) logger.Formatter.SetGetTimeNow(func() time.Time { return t }) diff --git a/internal/net/gphttp/accesslog/back_scanner.go b/internal/net/gphttp/accesslog/back_scanner.go index 2e55005..c05692b 100644 --- a/internal/net/gphttp/accesslog/back_scanner.go +++ b/internal/net/gphttp/accesslog/back_scanner.go @@ -7,7 +7,7 @@ import ( // BackScanner provides an interface to read a file backward line by line. type BackScanner struct { - file AccessLogIO + file supportRotate chunkSize int offset int64 buffer []byte @@ -18,7 +18,7 @@ type BackScanner struct { // NewBackScanner creates a new Scanner to read the file backward. // chunkSize determines the size of each read chunk from the end of the file. -func NewBackScanner(file AccessLogIO, chunkSize int) *BackScanner { +func NewBackScanner(file supportRotate, chunkSize int) *BackScanner { size, err := file.Seek(0, io.SeekEnd) if err != nil { return &BackScanner{err: err} diff --git a/internal/net/gphttp/accesslog/config.go b/internal/net/gphttp/accesslog/config.go index a1dbe2f..9a7eb46 100644 --- a/internal/net/gphttp/accesslog/config.go +++ b/internal/net/gphttp/accesslog/config.go @@ -1,6 +1,10 @@ package accesslog -import "github.com/yusing/go-proxy/internal/utils" +import ( + "errors" + + "github.com/yusing/go-proxy/internal/utils" +) type ( Format string @@ -19,7 +23,8 @@ type ( Config struct { BufferSize int `json:"buffer_size"` Format Format `json:"format" validate:"oneof=common combined json"` - Path string `json:"path" validate:"required"` + Path string `json:"path"` + Stdout bool `json:"stdout"` Filters Filters `json:"filters"` Fields Fields `json:"fields"` Retention *Retention `json:"retention"` @@ -34,6 +39,13 @@ var ( const DefaultBufferSize = 64 * 1024 // 64KB +func (cfg *Config) Validate() error { + if cfg.Path == "" && !cfg.Stdout { + return errors.New("path or stdout is required") + } + return nil +} + func DefaultConfig() *Config { return &Config{ BufferSize: DefaultBufferSize, diff --git a/internal/net/gphttp/accesslog/file_logger.go b/internal/net/gphttp/accesslog/file_logger.go index 1b3ace1..a3679ac 100644 --- a/internal/net/gphttp/accesslog/file_logger.go +++ b/internal/net/gphttp/accesslog/file_logger.go @@ -3,11 +3,10 @@ package accesslog import ( "fmt" "os" - "path" + pathPkg "path" "sync" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) @@ -27,16 +26,16 @@ var ( openedFilesMu sync.Mutex ) -func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { +func newFileIO(path string) (AccessLogIO, error) { openedFilesMu.Lock() var file *File - path := path.Clean(cfg.Path) + path = pathPkg.Clean(path) if opened, ok := openedFiles[path]; ok { opened.refCount.Add() file = opened } else { - f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) if err != nil { openedFilesMu.Unlock() return nil, fmt.Errorf("access log open error: %w", err) @@ -47,7 +46,7 @@ func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) } openedFilesMu.Unlock() - return NewAccessLogger(parent, file, cfg), nil + return file, nil } func (f *File) Close() error { diff --git a/internal/net/gphttp/accesslog/file_logger_test.go b/internal/net/gphttp/accesslog/file_logger_test.go index 0321a85..5159d01 100644 --- a/internal/net/gphttp/accesslog/file_logger_test.go +++ b/internal/net/gphttp/accesslog/file_logger_test.go @@ -16,7 +16,6 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { cfg := DefaultConfig() cfg.Path = "test.log" - parent := task.RootTask("test", false) loggerCount := 10 accessLogIOs := make([]AccessLogIO, loggerCount) @@ -33,9 +32,9 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { wg.Add(1) go func(index int) { defer wg.Done() - logger, err := NewFileAccessLogger(parent, cfg) + file, err := newFileIO(cfg.Path) ExpectNoError(t, err) - accessLogIOs[index] = logger.io + accessLogIOs[index] = file }(i) } @@ -59,7 +58,7 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) { loggers := make([]*AccessLogger, loggerCount) for i := range loggerCount { - loggers[i] = NewAccessLogger(parent, &file, cfg) + loggers[i] = NewAccessLoggerWithIO(parent, &file, cfg) } var wg sync.WaitGroup diff --git a/internal/net/gphttp/accesslog/multi_writer.go b/internal/net/gphttp/accesslog/multi_writer.go new file mode 100644 index 0000000..3577bc4 --- /dev/null +++ b/internal/net/gphttp/accesslog/multi_writer.go @@ -0,0 +1,46 @@ +package accesslog + +import "strings" + +type MultiWriter struct { + writers []AccessLogIO +} + +func NewMultiWriter(writers ...AccessLogIO) AccessLogIO { + if len(writers) == 0 { + return nil + } + if len(writers) == 1 { + return writers[0] + } + return &MultiWriter{ + writers: writers, + } +} + +func (w *MultiWriter) Write(p []byte) (n int, err error) { + for _, writer := range w.writers { + writer.Write(p) + } + return len(p), nil +} + +func (w *MultiWriter) Lock() { + for _, writer := range w.writers { + writer.Lock() + } +} + +func (w *MultiWriter) Unlock() { + for _, writer := range w.writers { + writer.Unlock() + } +} + +func (w *MultiWriter) Name() string { + names := make([]string, len(w.writers)) + for i, writer := range w.writers { + names[i] = writer.Name() + } + return strings.Join(names, ", ") +} diff --git a/internal/net/gphttp/accesslog/rotate.go b/internal/net/gphttp/accesslog/rotate.go index e93c22d..2f3e92e 100644 --- a/internal/net/gphttp/accesslog/rotate.go +++ b/internal/net/gphttp/accesslog/rotate.go @@ -2,11 +2,15 @@ package accesslog import ( "bytes" - "io" + ioPkg "io" "time" ) func (l *AccessLogger) rotate() (err error) { + io, ok := l.io.(supportRotate) + if !ok { + return nil + } // Get retention configuration config := l.Config().Retention var shouldKeep func(t time.Time, lineCount int) bool @@ -24,7 +28,7 @@ func (l *AccessLogger) rotate() (err error) { return nil // No retention policy set } - s := NewBackScanner(l.io, defaultChunkSize) + s := NewBackScanner(io, defaultChunkSize) nRead := 0 nLines := 0 for s.Scan() { @@ -40,11 +44,11 @@ func (l *AccessLogger) rotate() (err error) { } beg := int64(nRead) - if _, err := l.io.Seek(-beg, io.SeekEnd); err != nil { + if _, err := io.Seek(-beg, ioPkg.SeekEnd); err != nil { return err } buf := make([]byte, nRead) - if _, err := l.io.Read(buf); err != nil { + if _, err := io.Read(buf); err != nil { return err } @@ -55,8 +59,13 @@ func (l *AccessLogger) rotate() (err error) { } func (l *AccessLogger) writeTruncate(buf []byte) (err error) { + io, ok := l.io.(supportRotate) + if !ok { + return nil + } + // Seek to beginning and truncate - if _, err := l.io.Seek(0, 0); err != nil { + if _, err := io.Seek(0, 0); err != nil { return err } @@ -70,13 +79,13 @@ func (l *AccessLogger) writeTruncate(buf []byte) (err error) { } // Truncate file - if err = l.io.Truncate(int64(nWritten)); err != nil { + if err = io.Truncate(int64(nWritten)); err != nil { return err } // check bytes written == buffer size if nWritten != len(buf) { - return io.ErrShortWrite + return ioPkg.ErrShortWrite } return } diff --git a/internal/net/gphttp/accesslog/rotate_test.go b/internal/net/gphttp/accesslog/rotate_test.go index 8b81792..727a3cb 100644 --- a/internal/net/gphttp/accesslog/rotate_test.go +++ b/internal/net/gphttp/accesslog/rotate_test.go @@ -33,7 +33,7 @@ func TestParseLogTime(t *testing.T) { func TestRetentionCommonFormat(t *testing.T) { var file MockFile - logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{ + logger := NewAccessLoggerWithIO(task.RootTask("test", false), &file, &Config{ Format: FormatCommon, BufferSize: 1024, }) diff --git a/internal/net/gphttp/accesslog/stdout_logger.go b/internal/net/gphttp/accesslog/stdout_logger.go new file mode 100644 index 0000000..2e1f245 --- /dev/null +++ b/internal/net/gphttp/accesslog/stdout_logger.go @@ -0,0 +1,18 @@ +package accesslog + +import ( + "io" + "os" +) + +type StdoutLogger struct { + io.Writer +} + +var stdoutIO = &StdoutLogger{os.Stdout} + +func (l *StdoutLogger) Lock() {} +func (l *StdoutLogger) Unlock() {} +func (l *StdoutLogger) Name() string { + return "stdout" +} diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index ed75e9b..e8e8154 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -84,7 +84,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error { if s.UseAccessLog() { var err error - s.accessLogger, err = accesslog.NewFileAccessLogger(s.task, s.AccessLog) + s.accessLogger, err = accesslog.NewAccessLogger(s.task, s.AccessLog) if err != nil { s.task.Finish(err) return gperr.Wrap(err) diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 817b778..f1cce4c 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -116,7 +116,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { if r.UseAccessLog() { var err error - r.rp.AccessLogger, err = accesslog.NewFileAccessLogger(r.task, r.AccessLog) + r.rp.AccessLogger, err = accesslog.NewAccessLogger(r.task, r.AccessLog) if err != nil { r.task.Finish(err) return gperr.Wrap(err)