diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/http/accesslog/access_logger.go
index 5b95fe8..0656823 100644
--- a/internal/net/http/accesslog/access_logger.go
+++ b/internal/net/http/accesslog/access_logger.go
@@ -106,23 +106,23 @@ func (l *AccessLogger) Config() *Config {
 	return l.cfg
 }
 
-// func (l *AccessLogger) Rotate() error {
-// 	if l.cfg.Retention == nil {
-// 		return nil
-// 	}
-// 	l.io.Lock()
-// 	defer l.io.Unlock()
+func (l *AccessLogger) Rotate() error {
+	if l.cfg.Retention == nil {
+		return nil
+	}
+	l.io.Lock()
+	defer l.io.Unlock()
 
-// 	return l.cfg.Retention.rotateLogFile(l.io)
-// }
+	return l.cfg.Retention.rotateLogFile(l.io)
+}
 
 func (l *AccessLogger) Flush(force bool) {
 	l.flushMu.Lock()
 	if force || l.buf.Len() >= l.flushThreshold {
 		l.writeLine(l.buf.Bytes())
 		l.buf.Reset()
 	}
 	l.flushMu.Unlock()
 }
 
 func (l *AccessLogger) handleErr(err error) {
@@ -138,17 +138,16 @@ func (l *AccessLogger) start() {
 		l.task.Finish(nil)
 	}()
 
-	// periodic + threshold flush
-	flushTicker := time.NewTicker(5 * time.Second)
+	// periodic forced flush; the busy-spinning `default` branch is removed
+	flushTicker := time.NewTicker(3 * time.Second)
+	defer flushTicker.Stop()
 	for {
 		select {
 		case <-l.task.Context().Done():
 			return
 		case <-flushTicker.C:
 			l.Flush(true)
-		default:
-			l.Flush(false)
 		}
 	}
 }
diff --git a/internal/net/http/accesslog/config.go b/internal/net/http/accesslog/config.go
index 23d96c8..820c302 100644
--- a/internal/net/http/accesslog/config.go
+++ b/internal/net/http/accesslog/config.go
@@ -17,12 +17,12 @@ type (
 		Cookies FieldConfig `json:"cookies"`
 	}
 	Config struct {
-		BufferSize uint `json:"buffer_size" validate:"gte=1"`
-		Format Format `json:"format" validate:"oneof=common combined json"`
-		Path string `json:"path" validate:"required"`
-		Filters Filters `json:"filters"`
-		Fields Fields `json:"fields"`
-		// Retention *Retention
+		BufferSize uint `json:"buffer_size" 
validate:"gte=1"` + Format Format `json:"format" validate:"oneof=common combined json"` + Path string `json:"path" validate:"required"` + Filters Filters `json:"filters"` + Fields Fields `json:"fields"` + Retention *Retention `json:"retention"` } ) diff --git a/internal/net/http/accesslog/retention.go b/internal/net/http/accesslog/retention.go new file mode 100644 index 0000000..da31544 --- /dev/null +++ b/internal/net/http/accesslog/retention.go @@ -0,0 +1,198 @@ +package accesslog + +import ( + "bufio" + "bytes" + "io" + "strconv" + "time" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type Retention struct { + Days uint64 `json:"days"` + Last uint64 `json:"last"` +} + +const chunkSizeMax int64 = 128 * 1024 // 128KB + +var ( + ErrInvalidSyntax = E.New("invalid syntax") + ErrZeroValue = E.New("zero value") +) + +// Syntax: +// +// days|weeks|months +// +// last +// +// Parse implements strutils.Parser. +func (r *Retention) Parse(v string) (err error) { + split := strutils.SplitSpace(v) + if len(split) != 2 { + return ErrInvalidSyntax.Subject(v) + } + switch split[0] { + case "last": + r.Last, err = strconv.ParseUint(split[1], 10, 64) + default: // days|weeks|months + r.Days, err = strconv.ParseUint(split[0], 10, 64) + if err != nil { + return + } + switch split[1] { + case "days": + case "weeks": + r.Days *= 7 + case "months": + r.Days *= 30 + default: + return ErrInvalidSyntax.Subject("unit " + split[1]) + } + } + if r.Days == 0 && r.Last == 0 { + return ErrZeroValue + } + return +} + +func (r *Retention) rotateLogFile(file AccessLogIO) (err error) { + lastN := int(r.Last) + days := int(r.Days) + + // Seek to end to get file size + size, err := file.Seek(0, io.SeekEnd) + if err != nil { + return err + } + + // Initialize ring buffer for last N lines + lines := make([][]byte, 0, lastN|(days*1000)) + pos := size + unprocessed := 0 + + var chunk [chunkSizeMax]byte + var lastLine []byte + + var 
shouldStop func() bool
+	if days > 0 {
+		cutoff := time.Now().AddDate(0, 0, -days)
+		shouldStop = func() bool {
+			return len(lastLine) > 0 && !parseLogTime(lastLine).After(cutoff)
+		}
+	} else {
+		shouldStop = func() bool {
+			return len(lines) == lastN
+		}
+	}
+
+	// Read backwards until we have enough lines or reach start of file
+	for pos > 0 {
+		if pos > chunkSizeMax {
+			pos -= chunkSizeMax
+		} else {
+			pos = 0
+		}
+
+		// Seek to the current chunk
+		if _, err = file.Seek(pos, io.SeekStart); err != nil {
+			return err
+		}
+
+		var nRead int
+		// Read the chunk
+		if nRead, err = file.Read(chunk[unprocessed:]); err != nil {
+			return err
+		}
+
+		// last unprocessed bytes + read bytes
+		curChunk := chunk[:unprocessed+nRead]
+		unprocessed = len(curChunk)
+
+		// Split into lines
+		scanner := bufio.NewScanner(bytes.NewReader(curChunk))
+		for !shouldStop() && scanner.Scan() {
+			// clone: scanner.Bytes() aliases a buffer that is reused by
+			// subsequent Scan calls and overwritten on the next chunk read
+			lastLine = bytes.Clone(scanner.Bytes())
+			lines = append(lines, lastLine)
+			unprocessed -= len(lastLine)
+		}
+		if shouldStop() {
+			break
+		}
+
+		// move unprocessed bytes to the beginning for next iteration
+		copy(chunk[:], curChunk[unprocessed:])
+	}
+
+	if days > 0 && len(lines) > 0 && shouldStop() {
+		// the scan stopped on the first entry older than the cutoff;
+		// drop it and keep only the entries within the retention window
+		// (the previous Truncate(pos) wiped the whole file whenever it
+		// fit in a single chunk, since pos was already 0)
+		lines = lines[:len(lines)-1]
+	}
+
+	// write lines to buffer in reverse order
+	// since we read them backwards
+	var buf bytes.Buffer
+	for i := len(lines) - 1; i >= 0; i-- {
+		buf.Write(lines[i])
+		buf.WriteRune('\n')
+	}
+
+	return writeTruncate(file, &buf)
+}
+
+func writeTruncate(file AccessLogIO, buf *bytes.Buffer) (err error) {
+	// Seek to beginning and truncate
+	if _, err := file.Seek(0, 0); err != nil {
+		return err
+	}
+
+	buffered := bufio.NewWriter(file)
+	// Write buffer back to file
+	nWritten, err := buffered.Write(buf.Bytes())
+	if err != nil {
+		return err
+	}
+	if err = buffered.Flush(); err != nil {
+		return err
+	}
+
+	// Truncate file
+	if err = file.Truncate(int64(nWritten)); err != nil {
+		return err
+	}
+
+	// check bytes written == buffer size
+	if nWritten != 
buf.Len() {
+		return io.ErrShortWrite
+	}
+	return
+}
+
+func parseLogTime(line []byte) (t time.Time) {
+	if len(line) == 0 {
+		return
+	}
+
+	var start, end int
+	const jsonStart = len(`{"time":"`)
+	const jsonEnd = jsonStart + len(LogTimeFormat)
+
+	if line[0] == '{' { // possibly json log
+		if len(line) < jsonEnd {
+			return
+		}
+		start = jsonStart
+		end = jsonEnd
+	} else { // possibly common or combined format
+		// Format: - - [02/Jan/2006:15:04:05 -0700] ...
+		open := bytes.IndexRune(line, '[')
+		if open == -1 {
+			return
+		}
+		start = open + 1
+		rel := bytes.IndexRune(line[start:], ']')
+		if rel == -1 {
+			return
+		}
+		end = start + rel // make absolute; IndexRune result was relative
+	}
+
+	timeStr := line[start:end]
+	t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
+	return
+}
diff --git a/internal/net/http/accesslog/retention_test.go b/internal/net/http/accesslog/retention_test.go
new file mode 100644
index 0000000..606c51c
--- /dev/null
+++ b/internal/net/http/accesslog/retention_test.go
@@ -0,0 +1,148 @@
+package accesslog_test
+
+import (
+	"bytes"
+	"io"
+	"testing"
+	"time"
+
+	. "github.com/yusing/go-proxy/internal/net/http/accesslog"
+	"github.com/yusing/go-proxy/internal/utils/strutils"
+	. 
"github.com/yusing/go-proxy/internal/utils/testing"
+)
+
+func TestParseRetention(t *testing.T) {
+	tests := []struct {
+		input     string
+		expected  *Retention
+		shouldErr bool
+	}{
+		{"30 days", &Retention{Days: 30}, false},
+		{"2 weeks", &Retention{Days: 14}, false},
+		{"last 5", &Retention{Last: 5}, false},
+		{"invalid input", &Retention{}, true},
+	}
+
+	for _, test := range tests {
+		t.Run(test.input, func(t *testing.T) {
+			r := &Retention{}
+			err := r.Parse(test.input)
+			if test.shouldErr {
+				ExpectTrue(t, err != nil)
+			} else {
+				ExpectNoError(t, err)
+				ExpectDeepEqual(t, r, test.expected)
+			}
+		})
+	}
+}
+
+type mockFile struct {
+	data     []byte
+	position int64
+}
+
+func (m *mockFile) Seek(offset int64, whence int) (int64, error) {
+	switch whence {
+	case io.SeekStart:
+		m.position = offset
+	case io.SeekCurrent:
+		m.position += offset
+	case io.SeekEnd:
+		m.position = int64(len(m.data)) + offset
+	}
+	return m.position, nil
+}
+
+// Write writes at the current position like a real file, growing the
+// backing slice as needed (an append-only Write would ignore Seek and
+// never exercise the overwrite path of writeTruncate).
+func (m *mockFile) Write(p []byte) (n int, err error) {
+	n = len(p)
+	end := m.position + int64(n)
+	if end > int64(len(m.data)) {
+		grown := make([]byte, end)
+		copy(grown, m.data)
+		m.data = grown
+	}
+	copy(m.data[m.position:end], p)
+	m.position = end
+	return
+}
+
+func (m *mockFile) Name() string {
+	return "mock"
+}
+
+func (m *mockFile) Read(p []byte) (n int, err error) {
+	if m.position >= int64(len(m.data)) {
+		return 0, io.EOF
+	}
+	n = copy(p, m.data[m.position:])
+	m.position += int64(n)
+	return n, nil
+}
+
+func (m *mockFile) ReadAt(p []byte, off int64) (n int, err error) {
+	if off >= int64(len(m.data)) {
+		return 0, io.EOF
+	}
+	n = copy(p, m.data[off:])
+	m.position += int64(n)
+	return n, nil
+}
+
+func (m *mockFile) Close() error {
+	return nil
+}
+
+func (m *mockFile) Truncate(size int64) error {
+	m.data = m.data[:size]
+	m.position = size
+	return nil
+}
+
+func (m *mockFile) Lock()   {}
+func (m *mockFile) Unlock() {}
+
+func (m *mockFile) Count() int {
+	return bytes.Count(m.data[:m.position], []byte("\n"))
+}
+
+func (m *mockFile) Len() int64 {
+	return m.position
+}
+
+func TestRetentionCommonFormat(t *testing.T) {
+	file := mockFile{}
+	logger := NewAccessLogger(nil, &file, &Config{
+		Format:     FormatCommon,
+		BufferSize: 1024,
+	})
+	for range 10 {
+		logger.Log(req, resp)
+	}
+	logger.Flush(true)
+	// test.Finish(nil)
+
+	ExpectEqual(t, logger.Config().Retention, nil)
+	ExpectTrue(t, file.Len() > 0)
+	ExpectEqual(t, file.Count(), 10)
+
+	t.Run("keep last", func(t *testing.T) {
+		logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
+		ExpectEqual(t, logger.Config().Retention.Days, 0)
+		ExpectEqual(t, logger.Config().Retention.Last, 5)
+		ExpectNoError(t, logger.Rotate())
+		ExpectEqual(t, file.Count(), 5)
+	})
+
+	_ = file.Truncate(0)
+
+	timeNow := time.Now()
+	for i := range 10 {
+		logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
+			return timeNow.AddDate(0, 0, -i)
+		}
+		logger.Log(req, resp)
+	}
+	logger.Flush(true)
+
+	t.Run("keep days", func(t *testing.T) {
+		logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
+		ExpectEqual(t, logger.Config().Retention.Days, 3)
+		ExpectEqual(t, 
logger.Config().Retention.Last, 0) + ExpectNoError(t, logger.Rotate()) + ExpectEqual(t, file.Count(), 3) + }) +} diff --git a/internal/utils/ref_count.go b/internal/utils/ref_count.go index a61a68e..ac6eb01 100644 --- a/internal/utils/ref_count.go +++ b/internal/utils/ref_count.go @@ -1,15 +1,12 @@ package utils import ( - "sync" "sync/atomic" ) type RefCount struct { _ NoCopy - mu sync.Mutex - cond *sync.Cond refCount uint32 zeroCh chan struct{} } @@ -19,7 +16,6 @@ func NewRefCounter() *RefCount { refCount: 1, zeroCh: make(chan struct{}), } - rc.cond = sync.NewCond(&rc.mu) return rc } @@ -33,9 +29,6 @@ func (rc *RefCount) Add() { func (rc *RefCount) Sub() { if atomic.AddUint32(&rc.refCount, ^uint32(0)) == 0 { - rc.mu.Lock() close(rc.zeroCh) - rc.cond.Broadcast() - rc.mu.Unlock() } }