merge: access log rotation and enhancements

This commit is contained in:
yusing 2025-04-24 15:29:18 +08:00
parent d668b03175
commit 31812430f1
29 changed files with 1600 additions and 581 deletions

12
go.mod
View file

@ -33,7 +33,9 @@ require (
github.com/bytedance/sonic v1.13.2
github.com/docker/cli v28.1.1+incompatible
github.com/luthermonson/go-proxmox v0.2.2
github.com/spf13/afero v1.14.0
github.com/stretchr/testify v1.10.0
go.uber.org/atomic v1.11.0
)
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e
@ -49,7 +51,7 @@ require (
github.com/cloudflare/cloudflare-go v0.115.0 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/diskfs/go-diskfs v1.5.0 // indirect
github.com/diskfs/go-diskfs v1.6.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/djherbis/times v1.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect
@ -64,11 +66,11 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/jinzhu/copier v0.3.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/jinzhu/copier v0.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/magefile/mage v1.15.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.65 // indirect
@ -93,7 +95,7 @@ require (
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/arch v0.16.0 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sys v0.32.0 // indirect

24
go.sum
View file

@ -33,8 +33,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/diskfs/go-diskfs v1.5.0 h1:0SANkrab4ifiZBytk380gIesYh5Gc+3i40l7qsrYP4s=
github.com/diskfs/go-diskfs v1.5.0/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk=
github.com/diskfs/go-diskfs v1.6.0 h1:YmK5+vLSfkwC6kKKRTRPGaDGNF+Xh8FXeiNHwryDfu4=
github.com/diskfs/go-diskfs v1.6.0/go.mod h1:bRFumZeGFCO8C2KNswrQeuj2m1WCVr4Ms5IjWMczMDk=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c=
@ -107,15 +107,15 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslC
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nuc=
github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/jinzhu/copier v0.3.4 h1:mfU6jI9PtCeUjkjQ322dlff9ELjGDu975C2p/nrubVI=
github.com/jinzhu/copier v0.3.4/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@ -131,8 +131,8 @@ github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr32
github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg=
github.com/luthermonson/go-proxmox v0.2.2 h1:BZ7VEj302wxw2i/EwTcyEiBzQib8teocB2SSkLHyySY=
github.com/luthermonson/go-proxmox v0.2.2/go.mod h1:oyFgg2WwTEIF0rP6ppjiixOHa5ebK1p8OaRiFhvICBQ=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg=
github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@ -194,6 +194,8 @@ github.com/shirou/gopsutil/v4 v4.25.3 h1:SeA68lsu8gLggyMbmCn8cmp97V1TI9ld9sVzAUc
github.com/shirou/gopsutil/v4 v4.25.3/go.mod h1:xbuxyoZj+UsgnZrENu3lQivsngRR5BdjbJwf2fv4szA=
github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af h1:Sp5TG9f7K39yfB+If0vjp97vuT74F72r8hfRpP8jLU0=
github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA=
github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -234,8 +236,10 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/arch v0.16.0 h1:foMtLTdyOmIniqWCHjY6+JxuC54XP1fDwx4N0ASyW+U=
golang.org/x/arch v0.16.0/go.mod h1:JmwW7aLIoRUKgaTzhkiEFxvcEiQGyOg9BMonBJUS7EE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View file

@ -60,7 +60,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config)
return
}
ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg)
ep.accessLogger, err = accesslog.NewAccessLogger(parent, cfg)
if err != nil {
return
}

View file

@ -2,59 +2,99 @@ package accesslog
import (
"bufio"
"bytes"
"io"
"net/http"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/synk"
"golang.org/x/time/rate"
)
type (
AccessLogger struct {
task *task.Task
cfg *Config
io AccessLogIO
buffered *bufio.Writer
task *task.Task
cfg *Config
io AccessLogIO
buffered *bufio.Writer
supportRotate bool
lineBufPool *synk.BytesPool // buffer pool for formatting a single log line
errRateLimiter *rate.Limiter
logger zerolog.Logger
lineBufPool sync.Pool // buffer pool for formatting a single log line
Formatter
}
AccessLogIO interface {
io.ReadWriteCloser
io.ReadWriteSeeker
io.ReaderAt
io.Writer
sync.Locker
Name() string // file name or path
Truncate(size int64) error
}
Formatter interface {
// Format writes a log line to line without a trailing newline
Format(line *bytes.Buffer, req *http.Request, res *http.Response)
SetGetTimeNow(getTimeNow func() time.Time)
// AppendLog appends a log line to line with or without a trailing newline
AppendLog(line []byte, req *http.Request, res *http.Response) []byte
}
)
func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
const MinBufferSize = 4 * kilobyte
const (
flushInterval = 30 * time.Second
rotateInterval = time.Hour
)
func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
var ios []AccessLogIO
if cfg.Stdout {
ios = append(ios, stdoutIO)
}
if cfg.Path != "" {
io, err := newFileIO(cfg.Path)
if err != nil {
return nil, err
}
ios = append(ios, io)
}
if len(ios) == 0 {
return nil, nil
}
return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil
}
func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger {
return NewAccessLoggerWithIO(parent, NewMockFile(), cfg)
}
func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
if cfg.BufferSize == 0 {
cfg.BufferSize = DefaultBufferSize
}
if cfg.BufferSize < 4096 {
cfg.BufferSize = 4096
if cfg.BufferSize < MinBufferSize {
cfg.BufferSize = MinBufferSize
}
l := &AccessLogger{
task: parent.Subtask("accesslog"),
cfg: cfg,
io: io,
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
task: parent.Subtask("accesslog."+io.Name(), true),
cfg: cfg,
io: io,
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
lineBufPool: synk.NewBytesPool(1024, synk.DefaultMaxBytes),
errRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1),
logger: logging.With().Str("file", io.Name()).Logger(),
}
fmt := CommonFormatter{cfg: &l.cfg.Fields, GetTimeNow: time.Now}
fmt := CommonFormatter{cfg: &l.cfg.Fields}
switch l.cfg.Format {
case FormatCommon:
l.Formatter = &fmt
@ -66,14 +106,19 @@ func NewAccessLogger(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLog
panic("invalid access log format")
}
l.lineBufPool.New = func() any {
return bytes.NewBuffer(make([]byte, 0, 1024))
if _, ok := l.io.(supportRotate); ok {
l.supportRotate = true
}
go l.start()
return l
}
func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool {
func (l *AccessLogger) Config() *Config {
return l.cfg
}
func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool {
if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) ||
!l.cfg.Filters.Method.CheckKeep(req, res) ||
!l.cfg.Filters.Headers.CheckKeep(req, res) ||
@ -84,53 +129,63 @@ func (l *AccessLogger) checkKeep(req *http.Request, res *http.Response) bool {
}
func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
if !l.checkKeep(req, res) {
if !l.shouldLog(req, res) {
return
}
line := l.lineBufPool.Get().(*bytes.Buffer)
line.Reset()
line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line)
l.Formatter.Format(line, req, res)
line.WriteRune('\n')
l.write(line.Bytes())
line = l.Formatter.AppendLog(line, req, res)
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
l.lockWrite(line)
}
func (l *AccessLogger) LogError(req *http.Request, err error) {
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
}
func (l *AccessLogger) Config() *Config {
return l.cfg
func (l *AccessLogger) ShouldRotate() bool {
return l.cfg.Retention.IsValid() && l.supportRotate
}
func (l *AccessLogger) Rotate() error {
if l.cfg.Retention == nil {
return nil
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
if !l.ShouldRotate() {
return nil, nil
}
l.io.Lock()
defer l.io.Unlock()
return l.rotate()
return rotateLogFile(l.io.(supportRotate), l.cfg.Retention)
}
func (l *AccessLogger) handleErr(err error) {
gperr.LogError("failed to write access log", err)
if l.errRateLimiter.Allow() {
gperr.LogError("failed to write access log", err)
} else {
gperr.LogError("too many errors, stopping access log", err)
l.task.Finish(err)
}
}
func (l *AccessLogger) start() {
defer func() {
defer l.task.Finish(nil)
defer l.close()
if err := l.Flush(); err != nil {
l.handleErr(err)
}
l.close()
l.task.Finish(nil)
}()
// flushes the buffer every 30 seconds
flushTicker := time.NewTicker(30 * time.Second)
defer flushTicker.Stop()
rotateTicker := time.NewTicker(rotateInterval)
defer rotateTicker.Stop()
for {
select {
case <-l.task.Context().Done():
@ -139,6 +194,18 @@ func (l *AccessLogger) start() {
if err := l.Flush(); err != nil {
l.handleErr(err)
}
case <-rotateTicker.C:
if !l.ShouldRotate() {
continue
}
l.logger.Info().Msg("rotating access log file")
if res, err := l.Rotate(); err != nil {
l.handleErr(err)
} else if res != nil {
res.Print(&l.logger)
} else {
l.logger.Info().Msg("no rotation needed")
}
}
}
}
@ -150,18 +217,20 @@ func (l *AccessLogger) Flush() error {
}
func (l *AccessLogger) close() {
l.io.Lock()
defer l.io.Unlock()
l.io.Close()
if r, ok := l.io.(io.Closer); ok {
l.io.Lock()
defer l.io.Unlock()
r.Close()
}
}
func (l *AccessLogger) write(data []byte) {
func (l *AccessLogger) lockWrite(data []byte) {
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
_, err := l.buffered.Write(data)
l.io.Unlock()
if err != nil {
l.handleErr(err)
} else {
logging.Debug().Msg("access log flushed to " + l.io.Name())
logging.Trace().Msg("access log flushed to " + l.io.Name())
}
}

View file

@ -1,7 +1,6 @@
package accesslog_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
@ -11,7 +10,7 @@ import (
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
@ -22,14 +21,14 @@ const (
referer = "https://www.google.com/"
proto = "HTTP/1.1"
ua = "Go-http-client/1.1"
status = http.StatusOK
status = http.StatusNotFound
contentLength = 100
method = http.MethodGet
)
var (
testTask = task.RootTask("test", false)
testURL = Must(url.Parse("http://" + host + uri))
testURL = expect.Must(url.Parse("http://" + host + uri))
req = &http.Request{
RemoteAddr: remote,
Method: method,
@ -53,22 +52,20 @@ var (
)
func fmtLog(cfg *Config) (ts string, line string) {
var buf bytes.Buffer
buf := make([]byte, 0, 1024)
t := time.Now()
logger := NewAccessLogger(testTask, nil, cfg)
logger.Formatter.SetGetTimeNow(func() time.Time {
return t
})
logger.Format(&buf, req, resp)
return t.Format(LogTimeFormat), buf.String()
logger := NewMockAccessLogger(testTask, cfg)
MockTimeNow(t)
buf = logger.AppendLog(buf, req, resp)
return t.Format(LogTimeFormat), string(buf)
}
func TestAccessLoggerCommon(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCommon
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uri, proto, status, contentLength,
),
@ -79,7 +76,7 @@ func TestAccessLoggerCombined(t *testing.T) {
config := DefaultConfig()
config.Format = FormatCombined
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d \"%s\" \"%s\"",
host, remote, ts, method, uri, proto, status, contentLength, referer, ua,
),
@ -91,37 +88,79 @@ func TestAccessLoggerRedactQuery(t *testing.T) {
config.Format = FormatCommon
config.Fields.Query.Default = FieldModeRedact
ts, log := fmtLog(config)
ExpectEqual(t, log,
expect.Equal(t, log,
fmt.Sprintf("%s %s - - [%s] \"%s %s %s\" %d %d",
host, remote, ts, method, uriRedacted, proto, status, contentLength,
),
)
}
type JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
Path string `json:"path"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
t.Helper()
config.Format = FormatJSON
var entry JSONLogEntry
_, log := fmtLog(config)
err := json.Unmarshal([]byte(log), &entry)
ExpectNoError(t, err)
expect.NoError(t, err)
return entry
}
func TestAccessLoggerJSON(t *testing.T) {
config := DefaultConfig()
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.IP, remote)
ExpectEqual(t, entry.Method, method)
ExpectEqual(t, entry.Scheme, "http")
ExpectEqual(t, entry.Host, testURL.Host)
ExpectEqual(t, entry.URI, testURL.RequestURI())
ExpectEqual(t, entry.Protocol, proto)
ExpectEqual(t, entry.Status, status)
ExpectEqual(t, entry.ContentType, "text/plain")
ExpectEqual(t, entry.Size, contentLength)
ExpectEqual(t, entry.Referer, referer)
ExpectEqual(t, entry.UserAgent, ua)
ExpectEqual(t, len(entry.Headers), 0)
ExpectEqual(t, len(entry.Cookies), 0)
expect.Equal(t, entry.IP, remote)
expect.Equal(t, entry.Method, method)
expect.Equal(t, entry.Scheme, "http")
expect.Equal(t, entry.Host, testURL.Host)
expect.Equal(t, entry.Path, testURL.Path)
expect.Equal(t, entry.Protocol, proto)
expect.Equal(t, entry.Status, status)
expect.Equal(t, entry.ContentType, "text/plain")
expect.Equal(t, entry.Size, contentLength)
expect.Equal(t, entry.Referer, referer)
expect.Equal(t, entry.UserAgent, ua)
expect.Equal(t, len(entry.Headers), 0)
expect.Equal(t, len(entry.Cookies), 0)
if status >= 400 {
expect.Equal(t, entry.Error, http.StatusText(status))
}
}
func BenchmarkAccessLoggerJSON(b *testing.B) {
config := DefaultConfig()
config.Format = FormatJSON
logger := NewMockAccessLogger(testTask, config)
b.ResetTimer()
for b.Loop() {
logger.Log(req, resp)
}
}
func BenchmarkAccessLoggerCombined(b *testing.B) {
config := DefaultConfig()
config.Format = FormatCombined
logger := NewMockAccessLogger(testTask, config)
b.ResetTimer()
for b.Loop() {
logger.Log(req, resp)
}
}

View file

@ -2,32 +2,40 @@ package accesslog
import (
"bytes"
"errors"
"io"
)
// BackScanner provides an interface to read a file backward line by line.
type BackScanner struct {
file AccessLogIO
chunkSize int
offset int64
buffer []byte
line []byte
err error
file supportRotate
size int64
chunkSize int
chunkBuf []byte
offset int64
chunk []byte
line []byte
err error
}
// NewBackScanner creates a new Scanner to read the file backward.
// chunkSize determines the size of each read chunk from the end of the file.
func NewBackScanner(file AccessLogIO, chunkSize int) *BackScanner {
func NewBackScanner(file supportRotate, chunkSize int) *BackScanner {
size, err := file.Seek(0, io.SeekEnd)
if err != nil {
return &BackScanner{err: err}
}
return newBackScanner(file, size, make([]byte, chunkSize))
}
func newBackScanner(file supportRotate, fileSize int64, buf []byte) *BackScanner {
return &BackScanner{
file: file,
chunkSize: chunkSize,
offset: size,
size: size,
size: fileSize,
offset: fileSize,
chunkSize: len(buf),
chunkBuf: buf,
}
}
@ -41,9 +49,9 @@ func (s *BackScanner) Scan() bool {
// Read chunks until a newline is found or the file is fully read
for {
// Check if there's a line in the buffer
if idx := bytes.LastIndexByte(s.buffer, '\n'); idx >= 0 {
s.line = s.buffer[idx+1:]
s.buffer = s.buffer[:idx]
if idx := bytes.LastIndexByte(s.chunk, '\n'); idx >= 0 {
s.line = s.chunk[idx+1:]
s.chunk = s.chunk[:idx]
if len(s.line) > 0 {
return true
}
@ -53,9 +61,9 @@ func (s *BackScanner) Scan() bool {
for {
if s.offset <= 0 {
// No more data to read; check remaining buffer
if len(s.buffer) > 0 {
s.line = s.buffer
s.buffer = nil
if len(s.chunk) > 0 {
s.line = s.chunk
s.chunk = nil
return true
}
return false
@ -63,22 +71,27 @@ func (s *BackScanner) Scan() bool {
newOffset := max(0, s.offset-int64(s.chunkSize))
chunkSize := s.offset - newOffset
chunk := make([]byte, chunkSize)
chunk := s.chunkBuf[:chunkSize]
n, err := s.file.ReadAt(chunk, newOffset)
if err != nil && err != io.EOF {
s.err = err
if err != nil {
if !errors.Is(err, io.EOF) {
s.err = err
}
return false
} else if n == 0 {
return false
}
// Prepend the chunk to the buffer
s.buffer = append(chunk[:n], s.buffer...)
clone := append([]byte{}, chunk[:n]...)
s.chunk = append(clone, s.chunk...)
s.offset = newOffset
// Check for newline in the updated buffer
if idx := bytes.LastIndexByte(s.buffer, '\n'); idx >= 0 {
s.line = s.buffer[idx+1:]
s.buffer = s.buffer[:idx]
if idx := bytes.LastIndexByte(s.chunk, '\n'); idx >= 0 {
s.line = s.chunk[idx+1:]
s.chunk = s.chunk[:idx]
if len(s.line) > 0 {
return true
}
@ -102,3 +115,12 @@ func (s *BackScanner) FileSize() int64 {
func (s *BackScanner) Err() error {
return s.err
}
func (s *BackScanner) Reset() error {
_, err := s.file.Seek(0, io.SeekStart)
if err != nil {
return err
}
*s = *newBackScanner(s.file, s.size, s.chunkBuf)
return nil
}

View file

@ -2,8 +2,16 @@ package accesslog
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/spf13/afero"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBackScanner(t *testing.T) {
@ -52,7 +60,7 @@ func TestBackScanner(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup mock file
mockFile := &MockFile{}
mockFile := NewMockFile()
_, err := mockFile.Write([]byte(tt.input))
if err != nil {
t.Fatalf("failed to write to mock file: %v", err)
@ -94,7 +102,7 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) {
for _, chunkSize := range chunkSizes {
t.Run(fmt.Sprintf("chunk_size_%d", chunkSize), func(t *testing.T) {
mockFile := &MockFile{}
mockFile := NewMockFile()
_, err := mockFile.Write([]byte(input))
if err != nil {
t.Fatalf("failed to write to mock file: %v", err)
@ -125,3 +133,136 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) {
})
}
}
func logEntry() []byte {
accesslog := NewMockAccessLogger(task.RootTask("test", false), &Config{
Format: FormatJSON,
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))
srv.URL = "http://localhost:8080"
defer srv.Close()
// make a request to the server
req, _ := http.NewRequest("GET", srv.URL, nil)
res := httptest.NewRecorder()
// serve the request
srv.Config.Handler.ServeHTTP(res, req)
b := accesslog.AppendLog(nil, req, res.Result())
if b[len(b)-1] != '\n' {
b = append(b, '\n')
}
return b
}
func TestReset(t *testing.T) {
file, err := afero.TempFile(afero.NewOsFs(), "", "accesslog")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer os.Remove(file.Name())
line := logEntry()
nLines := 1000
for range nLines {
_, err := file.Write(line)
if err != nil {
t.Fatalf("failed to write to temp file: %v", err)
}
}
linesRead := 0
s := NewBackScanner(file, defaultChunkSize)
for s.Scan() {
linesRead++
}
if err := s.Err(); err != nil {
t.Errorf("scanner error: %v", err)
}
expect.Equal(t, linesRead, nLines)
s.Reset()
linesRead = 0
for s.Scan() {
linesRead++
}
if err := s.Err(); err != nil {
t.Errorf("scanner error: %v", err)
}
expect.Equal(t, linesRead, nLines)
}
// 100000 log entries
func BenchmarkBackScanner(b *testing.B) {
mockFile := NewMockFile()
line := logEntry()
for range 100000 {
_, _ = mockFile.Write(line)
}
for i := range 14 {
chunkSize := (2 << i) * kilobyte
scanner := NewBackScanner(mockFile, chunkSize)
name := strutils.FormatByteSize(chunkSize)
b.ResetTimer()
b.Run(name, func(b *testing.B) {
for b.Loop() {
_ = scanner.Reset()
for scanner.Scan() {
}
}
})
}
}
func BenchmarkBackScannerRealFile(b *testing.B) {
file, err := afero.TempFile(afero.NewOsFs(), "", "accesslog")
if err != nil {
b.Fatalf("failed to create temp file: %v", err)
}
defer os.Remove(file.Name())
for range 10000 {
_, err = file.Write(logEntry())
if err != nil {
b.Fatalf("failed to write to temp file: %v", err)
}
}
scanner := NewBackScanner(file, 256*kilobyte)
b.ResetTimer()
for scanner.Scan() {
}
if err := scanner.Err(); err != nil {
b.Errorf("scanner error: %v", err)
}
}
/*
BenchmarkBackScanner
BenchmarkBackScanner/2_KiB
BenchmarkBackScanner/2_KiB-20 52 23254071 ns/op 67596663 B/op 26420 allocs/op
BenchmarkBackScanner/4_KiB
BenchmarkBackScanner/4_KiB-20 55 20961059 ns/op 62529378 B/op 13211 allocs/op
BenchmarkBackScanner/8_KiB
BenchmarkBackScanner/8_KiB-20 64 18242460 ns/op 62951141 B/op 6608 allocs/op
BenchmarkBackScanner/16_KiB
BenchmarkBackScanner/16_KiB-20 52 20162076 ns/op 62940256 B/op 3306 allocs/op
BenchmarkBackScanner/32_KiB
BenchmarkBackScanner/32_KiB-20 54 19247968 ns/op 67553645 B/op 1656 allocs/op
BenchmarkBackScanner/64_KiB
BenchmarkBackScanner/64_KiB-20 60 20909046 ns/op 64053342 B/op 827 allocs/op
BenchmarkBackScanner/128_KiB
BenchmarkBackScanner/128_KiB-20 68 17759890 ns/op 62201945 B/op 414 allocs/op
BenchmarkBackScanner/256_KiB
BenchmarkBackScanner/256_KiB-20 52 19531877 ns/op 61030487 B/op 208 allocs/op
BenchmarkBackScanner/512_KiB
BenchmarkBackScanner/512_KiB-20 54 19124656 ns/op 61030485 B/op 208 allocs/op
BenchmarkBackScanner/1_MiB
BenchmarkBackScanner/1_MiB-20 67 17078936 ns/op 61030495 B/op 208 allocs/op
BenchmarkBackScanner/2_MiB
BenchmarkBackScanner/2_MiB-20 66 18467421 ns/op 61030492 B/op 208 allocs/op
BenchmarkBackScanner/4_MiB
BenchmarkBackScanner/4_MiB-20 68 17214573 ns/op 61030486 B/op 208 allocs/op
BenchmarkBackScanner/8_MiB
BenchmarkBackScanner/8_MiB-20 57 18235229 ns/op 61030492 B/op 208 allocs/op
BenchmarkBackScanner/16_MiB
BenchmarkBackScanner/16_MiB-20 57 19343441 ns/op 61030499 B/op 208 allocs/op
*/

View file

@ -1,6 +1,9 @@
package accesslog
import "github.com/yusing/go-proxy/internal/utils"
import (
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Format string
@ -19,7 +22,8 @@ type (
Config struct {
BufferSize int `json:"buffer_size"`
Format Format `json:"format" validate:"oneof=common combined json"`
Path string `json:"path" validate:"required"`
Path string `json:"path"`
Stdout bool `json:"stdout"`
Filters Filters `json:"filters"`
Fields Fields `json:"fields"`
Retention *Retention `json:"retention"`
@ -30,14 +34,24 @@ var (
FormatCommon Format = "common"
FormatCombined Format = "combined"
FormatJSON Format = "json"
AvailableFormats = []Format{FormatCommon, FormatCombined, FormatJSON}
)
const DefaultBufferSize = 64 * 1024 // 64KB
const DefaultBufferSize = 64 * kilobyte // 64KB
func (cfg *Config) Validate() gperr.Error {
if cfg.Path == "" && !cfg.Stdout {
return gperr.New("path or stdout is required")
}
return nil
}
func DefaultConfig() *Config {
return &Config{
BufferSize: DefaultBufferSize,
Format: FormatCombined,
Retention: &Retention{Days: 30},
Fields: Fields{
Headers: FieldConfig{
Default: FieldModeDrop,

View file

@ -6,7 +6,7 @@ import (
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewConfig(t *testing.T) {
@ -27,27 +27,27 @@ func TestNewConfig(t *testing.T) {
"proxy.fields.cookies.config.foo": "keep",
}
parsed, err := docker.ParseLabels(labels)
ExpectNoError(t, err)
expect.NoError(t, err)
var config Config
err = utils.Deserialize(parsed, &config)
ExpectNoError(t, err)
expect.NoError(t, err)
ExpectEqual(t, config.BufferSize, 10)
ExpectEqual(t, config.Format, FormatCombined)
ExpectEqual(t, config.Path, "/tmp/access.log")
ExpectEqual(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
ExpectEqual(t, len(config.Filters.Method.Values), 2)
ExpectEqual(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
ExpectEqual(t, len(config.Filters.Headers.Values), 2)
ExpectEqual(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
ExpectTrue(t, config.Filters.Headers.Negative)
ExpectEqual(t, len(config.Filters.CIDR.Values), 1)
ExpectEqual(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
ExpectEqual(t, config.Fields.Headers.Default, FieldModeKeep)
ExpectEqual(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
ExpectEqual(t, config.Fields.Query.Default, FieldModeDrop)
ExpectEqual(t, config.Fields.Query.Config["foo"], FieldModeKeep)
ExpectEqual(t, config.Fields.Cookies.Default, FieldModeRedact)
ExpectEqual(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
expect.Equal(t, config.BufferSize, 10)
expect.Equal(t, config.Format, FormatCombined)
expect.Equal(t, config.Path, "/tmp/access.log")
expect.Equal(t, config.Filters.StatusCodes.Values, []*StatusCodeRange{{Start: 200, End: 299}})
expect.Equal(t, len(config.Filters.Method.Values), 2)
expect.Equal(t, config.Filters.Method.Values, []HTTPMethod{"GET", "POST"})
expect.Equal(t, len(config.Filters.Headers.Values), 2)
expect.Equal(t, config.Filters.Headers.Values, []*HTTPHeader{{Key: "foo", Value: "bar"}, {Key: "baz", Value: ""}})
expect.True(t, config.Filters.Headers.Negative)
expect.Equal(t, len(config.Filters.CIDR.Values), 1)
expect.Equal(t, config.Filters.CIDR.Values[0].String(), "192.168.10.0/24")
expect.Equal(t, config.Fields.Headers.Default, FieldModeKeep)
expect.Equal(t, config.Fields.Headers.Config["foo"], FieldModeRedact)
expect.Equal(t, config.Fields.Query.Default, FieldModeDrop)
expect.Equal(t, config.Fields.Query.Config["foo"], FieldModeKeep)
expect.Equal(t, config.Fields.Cookies.Default, FieldModeRedact)
expect.Equal(t, config.Fields.Cookies.Config["foo"], FieldModeKeep)
}

View file

@ -1,8 +1,11 @@
package accesslog
import (
"iter"
"net/http"
"net/url"
"github.com/rs/zerolog"
)
type (
@ -21,83 +24,181 @@ const (
RedactedValue = "REDACTED"
)
func processMap[V any](cfg *FieldConfig, m map[string]V, redactedV V) map[string]V {
type mapStringStringIter interface {
Iter(yield func(k string, v []string) bool)
MarshalZerologObject(e *zerolog.Event)
}
type mapStringStringSlice struct {
m map[string][]string
}
func (m mapStringStringSlice) Iter(yield func(k string, v []string) bool) {
for k, v := range m.m {
if !yield(k, v) {
return
}
}
}
func (m mapStringStringSlice) MarshalZerologObject(e *zerolog.Event) {
for k, v := range m.m {
e.Strs(k, v)
}
}
type mapStringStringRedacted struct {
m map[string][]string
}
func (m mapStringStringRedacted) Iter(yield func(k string, v []string) bool) {
for k := range m.m {
if !yield(k, []string{RedactedValue}) {
return
}
}
}
func (m mapStringStringRedacted) MarshalZerologObject(e *zerolog.Event) {
for k, v := range m.Iter {
e.Strs(k, v)
}
}
type mapStringStringSliceWithConfig struct {
m map[string][]string
cfg *FieldConfig
}
// Iter yields each key according to its configured mode: kept keys pass
// through unchanged, redacted keys yield RedactedValue, and dropped
// keys (any other mode) are skipped entirely.
func (m mapStringStringSliceWithConfig) Iter(yield func(k string, v []string) bool) {
	for key, vals := range m.m {
		mode, ok := m.cfg.Config[key]
		if !ok {
			mode = m.cfg.Default
		}
		switch mode {
		case FieldModeKeep:
			if !yield(key, vals) {
				return
			}
		case FieldModeRedact:
			if !yield(key, []string{RedactedValue}) {
				return
			}
		}
	}
}
// MarshalZerologObject logs only the key/value pairs that survive the
// per-key filtering performed by Iter.
func (m mapStringStringSliceWithConfig) MarshalZerologObject(e *zerolog.Event) {
	for k, v := range m.Iter {
		e.Strs(k, v)
	}
}
// mapStringStringDrop is the FieldModeDrop view: its iterator yields
// nothing and it logs nothing.
type mapStringStringDrop struct{}

func (m mapStringStringDrop) Iter(yield func(k string, v []string) bool) {}
func (m mapStringStringDrop) MarshalZerologObject(e *zerolog.Event)      {}

// Shared stateless instance; safe to reuse across goroutines.
var mapStringStringDropIter mapStringStringIter = mapStringStringDrop{}
func mapIter[Map http.Header | url.Values](cfg *FieldConfig, m Map) mapStringStringIter {
if len(cfg.Config) == 0 {
switch cfg.Default {
case FieldModeKeep:
return m
return mapStringStringSlice{m: m}
case FieldModeDrop:
return nil
return mapStringStringDropIter
case FieldModeRedact:
redacted := make(map[string]V)
for k := range m {
redacted[k] = redactedV
}
return redacted
return mapStringStringRedacted{m: m}
}
}
return mapStringStringSliceWithConfig{m: m, cfg: cfg}
}
if len(m) == 0 {
return m
}
// slice is a generic filtered view over a []V (e.g. []*http.Cookie):
// getKey/getVal project each element to a key/value string, and cfg
// decides per key whether it is kept, redacted, or dropped.
type slice[V any] struct {
	s      []V
	getKey func(V) string
	getVal func(V) string
	cfg    *FieldConfig
}
newMap := make(map[string]V, len(m))
for k := range m {
// sliceIter mirrors mapStringStringIter for single-valued collections:
// surviving key -> value pairs as a push iterator and a zerolog object.
type sliceIter interface {
	Iter(yield func(k string, v string) bool)
	MarshalZerologObject(e *zerolog.Event)
}
func (s *slice[V]) Iter(yield func(k string, v string) bool) {
for _, v := range s.s {
k := s.getKey(v)
var mode FieldMode
var ok bool
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
if mode, ok = s.cfg.Config[k]; !ok {
mode = s.cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = m[k]
if !yield(k, s.getVal(v)) {
return
}
case FieldModeRedact:
newMap[k] = redactedV
if !yield(k, RedactedValue) {
return
}
}
}
return newMap
}
func processSlice[V any, VReturn any](cfg *FieldConfig, s []V, getKey func(V) string, convert func(V) VReturn, redact func(V) VReturn) map[string]VReturn {
// sliceDrop is the drop-everything sliceIter: yields nothing, logs nothing.
type sliceDrop struct{}

func (s sliceDrop) Iter(yield func(k string, v string) bool) {}
func (s sliceDrop) MarshalZerologObject(e *zerolog.Event)    {}

// Shared stateless instance; safe to reuse across goroutines.
var sliceDropIter sliceIter = sliceDrop{}
// MarshalZerologObject logs only the key/value pairs that survive the
// per-key filtering performed by Iter.
func (s *slice[V]) MarshalZerologObject(e *zerolog.Event) {
	for k, v := range s.Iter {
		e.Str(k, v)
	}
}
func iterSlice[V any](cfg *FieldConfig, s []V, getKey func(V) string, getVal func(V) string) sliceIter {
if len(s) == 0 ||
len(cfg.Config) == 0 && cfg.Default == FieldModeDrop {
return nil
return sliceDropIter
}
newMap := make(map[string]VReturn, len(s))
for _, v := range s {
var mode FieldMode
var ok bool
k := getKey(v)
if mode, ok = cfg.Config[k]; !ok {
mode = cfg.Default
}
switch mode {
case FieldModeKeep:
newMap[k] = convert(v)
case FieldModeRedact:
newMap[k] = redact(v)
}
}
return newMap
return &slice[V]{s: s, getKey: getKey, getVal: getVal, cfg: cfg}
}
func (cfg *FieldConfig) ProcessHeaders(headers http.Header) http.Header {
return processMap(cfg, headers, []string{RedactedValue})
// IterHeaders returns an iterator over the request headers after
// applying this FieldConfig's keep/redact/drop rules.
func (cfg *FieldConfig) IterHeaders(headers http.Header) iter.Seq2[string, []string] {
	return mapIter(cfg, headers).Iter
}
func (cfg *FieldConfig) ProcessQuery(q url.Values) url.Values {
return processMap(cfg, q, []string{RedactedValue})
// ZerologHeaders returns the filtered headers as a zerolog log object,
// avoiding an intermediate map allocation when writing JSON logs.
func (cfg *FieldConfig) ZerologHeaders(headers http.Header) zerolog.LogObjectMarshaler {
	return mapIter(cfg, headers)
}
func (cfg *FieldConfig) ProcessCookies(cookies []*http.Cookie) map[string]string {
return processSlice(cfg, cookies,
func(c *http.Cookie) string {
return c.Name
},
func(c *http.Cookie) string {
return c.Value
},
func(c *http.Cookie) string {
return RedactedValue
})
// IterQuery returns an iterator over the query parameters after
// applying this FieldConfig's keep/redact/drop rules.
func (cfg *FieldConfig) IterQuery(q url.Values) iter.Seq2[string, []string] {
	return mapIter(cfg, q).Iter
}

// ZerologQuery returns the filtered query parameters as a zerolog log
// object, avoiding an intermediate map allocation.
func (cfg *FieldConfig) ZerologQuery(q url.Values) zerolog.LogObjectMarshaler {
	return mapIter(cfg, q)
}
// cookieGetKey projects a cookie to its name; used as the slice key getter.
func cookieGetKey(c *http.Cookie) string {
	return c.Name
}

// cookieGetValue projects a cookie to its value; used as the slice value getter.
func cookieGetValue(c *http.Cookie) string {
	return c.Value
}
// IterCookies returns an iterator over the request cookies after
// applying this FieldConfig's keep/redact/drop rules.
func (cfg *FieldConfig) IterCookies(cookies []*http.Cookie) iter.Seq2[string, string] {
	return iterSlice(cfg, cookies, cookieGetKey, cookieGetValue).Iter
}

// ZerologCookies returns the filtered cookies as a zerolog log object,
// avoiding an intermediate map allocation.
func (cfg *FieldConfig) ZerologCookies(cookies []*http.Cookie) zerolog.LogObjectMarshaler {
	return iterSlice(cfg, cookies, cookieGetKey, cookieGetValue)
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
// Cookie header should be removed,
@ -15,7 +15,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
entry := getJSONEntry(t, config)
for k, v := range req.Header {
if k != "Cookie" {
ExpectEqual(t, entry.Headers[k], v)
expect.Equal(t, entry.Headers[k], v)
}
}
@ -24,8 +24,8 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
"User-Agent": FieldModeDrop,
}
entry = getJSONEntry(t, config)
ExpectEqual(t, entry.Headers["Referer"], []string{RedactedValue})
ExpectEqual(t, entry.Headers["User-Agent"], nil)
expect.Equal(t, entry.Headers["Referer"], []string{RedactedValue})
expect.Equal(t, entry.Headers["User-Agent"], nil)
}
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
@ -33,7 +33,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
config.Fields.Headers.Default = FieldModeDrop
entry := getJSONEntry(t, config)
for k := range req.Header {
ExpectEqual(t, entry.Headers[k], nil)
expect.Equal(t, entry.Headers[k], nil)
}
config.Fields.Headers.Config = map[string]FieldMode{
@ -41,18 +41,17 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
"User-Agent": FieldModeRedact,
}
entry = getJSONEntry(t, config)
ExpectEqual(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
ExpectEqual(t, entry.Headers["User-Agent"], []string{RedactedValue})
expect.Equal(t, entry.Headers["Referer"], []string{req.Header.Get("Referer")})
expect.Equal(t, entry.Headers["User-Agent"], []string{RedactedValue})
}
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
config := DefaultConfig()
config.Fields.Headers.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for k := range req.Header {
if k != "Cookie" {
ExpectEqual(t, entry.Headers[k], []string{RedactedValue})
expect.Equal(t, entry.Headers[k], []string{RedactedValue})
}
}
}
@ -62,9 +61,8 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) {
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeKeep
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], cookie.Value)
expect.Equal(t, entry.Cookies[cookie.Name], cookie.Value)
}
}
@ -73,9 +71,8 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) {
config.Fields.Headers.Default = FieldModeKeep
config.Fields.Cookies.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, len(entry.Headers["Cookie"]), 0)
for _, cookie := range req.Cookies() {
ExpectEqual(t, entry.Cookies[cookie.Name], RedactedValue)
expect.Equal(t, entry.Cookies[cookie.Name], RedactedValue)
}
}
@ -83,14 +80,14 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeDrop
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.Query["foo"], nil)
ExpectEqual(t, entry.Query["bar"], nil)
expect.Equal(t, entry.Query["foo"], nil)
expect.Equal(t, entry.Query["bar"], nil)
}
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
config := DefaultConfig()
config.Fields.Query.Default = FieldModeRedact
entry := getJSONEntry(t, config)
ExpectEqual(t, entry.Query["foo"], []string{RedactedValue})
ExpectEqual(t, entry.Query["bar"], []string{RedactedValue})
expect.Equal(t, entry.Query["foo"], []string{RedactedValue})
expect.Equal(t, entry.Query["bar"], []string{RedactedValue})
}

View file

@ -3,11 +3,10 @@ package accesslog
import (
"fmt"
"os"
"path"
pathPkg "path"
"sync"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
@ -27,16 +26,16 @@ var (
openedFilesMu sync.Mutex
)
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
func newFileIO(path string) (AccessLogIO, error) {
openedFilesMu.Lock()
var file *File
path := path.Clean(cfg.Path)
path = pathPkg.Clean(path)
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
file = opened
} else {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
openedFilesMu.Unlock()
return nil, fmt.Errorf("access log open error: %w", err)
@ -47,7 +46,7 @@ func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error)
}
openedFilesMu.Unlock()
return NewAccessLogger(parent, file, cfg), nil
return file, nil
}
func (f *File) Close() error {

View file

@ -6,7 +6,7 @@ import (
"sync"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/task"
)
@ -16,26 +16,25 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
cfg := DefaultConfig()
cfg.Path = "test.log"
parent := task.RootTask("test", false)
loggerCount := 10
accessLogIOs := make([]AccessLogIO, loggerCount)
// make test log file
file, err := os.Create(cfg.Path)
ExpectNoError(t, err)
expect.NoError(t, err)
file.Close()
t.Cleanup(func() {
ExpectNoError(t, os.Remove(cfg.Path))
expect.NoError(t, os.Remove(cfg.Path))
})
for i := range loggerCount {
wg.Add(1)
go func(index int) {
defer wg.Done()
logger, err := NewFileAccessLogger(parent, cfg)
ExpectNoError(t, err)
accessLogIOs[index] = logger.io
file, err := newFileIO(cfg.Path)
expect.NoError(t, err)
accessLogIOs[index] = file
}(i)
}
@ -43,12 +42,12 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
firstIO := accessLogIOs[0]
for _, io := range accessLogIOs {
ExpectEqual(t, io, firstIO)
expect.Equal(t, io, firstIO)
}
}
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
var file MockFile
file := NewMockFile()
cfg := DefaultConfig()
cfg.BufferSize = 1024
@ -59,15 +58,15 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
loggers := make([]*AccessLogger, loggerCount)
for i := range loggerCount {
loggers[i] = NewAccessLogger(parent, &file, cfg)
loggers[i] = NewAccessLoggerWithIO(parent, file, cfg)
}
var wg sync.WaitGroup
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
resp := &http.Response{StatusCode: http.StatusOK}
wg.Add(len(loggers))
for _, logger := range loggers {
wg.Add(1)
go func(l *AccessLogger) {
defer wg.Done()
parallelLog(l, req, resp, logCountPerLogger)
@ -78,8 +77,8 @@ func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
wg.Wait()
expected := loggerCount * logCountPerLogger
actual := file.LineCount()
ExpectEqual(t, actual, expected)
actual := file.NumLines()
expect.Equal(t, actual, expected)
}
func parallelLog(logger *AccessLogger, req *http.Request, resp *http.Response, n int) {

View file

@ -6,7 +6,7 @@ import (
"strings"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
gpnet "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -24,7 +24,9 @@ type (
Key, Value string
}
Host string
CIDR struct{ types.CIDR }
CIDR struct {
gpnet.CIDR
}
)
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
@ -86,7 +88,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h)
}
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
ip = req.RemoteAddr

View file

@ -1,12 +1,14 @@
package accesslog_test
import (
"net"
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
gpnet "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestStatusCodeFilter(t *testing.T) {
@ -15,20 +17,20 @@ func TestStatusCodeFilter(t *testing.T) {
}
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{}
ExpectTrue(t, filter.CheckKeep(nil, nil))
expect.True(t, filter.CheckKeep(nil, nil))
// keep any 2xx 3xx (inclusive)
filter.Values = values
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
@ -37,20 +39,20 @@ func TestStatusCodeFilter(t *testing.T) {
filter := &LogFilter[*StatusCodeRange]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(nil, nil))
expect.False(t, filter.CheckKeep(nil, nil))
// drop any 2xx 3xx (inclusive)
filter.Values = values
ExpectTrue(t, filter.CheckKeep(nil, &http.Response{
expect.True(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusForbidden,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusOK,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusMultipleChoices,
}))
ExpectFalse(t, filter.CheckKeep(nil, &http.Response{
expect.False(t, filter.CheckKeep(nil, &http.Response{
StatusCode: http.StatusPermanentRedirect,
}))
})
@ -59,19 +61,19 @@ func TestStatusCodeFilter(t *testing.T) {
func TestMethodFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[HTTPMethod]{}
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// keep get only
filter.Values = []HTTPMethod{http.MethodGet}
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
})
@ -80,19 +82,19 @@ func TestMethodFilter(t *testing.T) {
filter := &LogFilter[HTTPMethod]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
// drop post only
filter.Values = []HTTPMethod{http.MethodPost}
ExpectFalse(t, filter.CheckKeep(&http.Request{
expect.False(t, filter.CheckKeep(&http.Request{
Method: http.MethodPost,
}, nil))
ExpectTrue(t, filter.CheckKeep(&http.Request{
expect.True(t, filter.CheckKeep(&http.Request{
Method: http.MethodGet,
}, nil))
})
@ -112,53 +114,54 @@ func TestHeaderFilter(t *testing.T) {
headerFoo := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo"),
}
ExpectEqual(t, headerFoo[0].Key, "Foo")
ExpectEqual(t, headerFoo[0].Value, "")
expect.Equal(t, headerFoo[0].Key, "Foo")
expect.Equal(t, headerFoo[0].Value, "")
headerFooBar := []*HTTPHeader{
strutils.MustParse[*HTTPHeader]("Foo=bar"),
}
ExpectEqual(t, headerFooBar[0].Key, "Foo")
ExpectEqual(t, headerFooBar[0].Value, "bar")
expect.Equal(t, headerFooBar[0].Key, "Foo")
expect.Equal(t, headerFooBar[0].Value, "bar")
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{}
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
// keep any foo
filter.Values = headerFoo
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
// keep foo == bar
filter.Values = headerFooBar
ExpectTrue(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.True(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*HTTPHeader]{
Negative: true,
}
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
// drop any foo
filter.Values = headerFoo
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectFalse(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.False(t, filter.CheckKeep(fooBaz, nil))
// drop foo == bar
filter.Values = headerFooBar
ExpectFalse(t, filter.CheckKeep(fooBar, nil))
ExpectTrue(t, filter.CheckKeep(fooBaz, nil))
expect.False(t, filter.CheckKeep(fooBar, nil))
expect.True(t, filter.CheckKeep(fooBaz, nil))
})
}
func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{
strutils.MustParse[*CIDR]("192.168.10.0/24"),
}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
cidr := []*CIDR{{gpnet.CIDR{
IP: net.ParseIP("192.168.10.0"),
Mask: net.CIDRMask(24, 32),
}}}
expect.Equal(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",
}
@ -168,21 +171,21 @@ func TestCIDRFilter(t *testing.T) {
t.Run("positive", func(t *testing.T) {
filter := &LogFilter[*CIDR]{}
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
expect.True(t, filter.CheckKeep(inCIDR, nil))
expect.True(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectTrue(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
expect.True(t, filter.CheckKeep(inCIDR, nil))
expect.False(t, filter.CheckKeep(notInCIDR, nil))
})
t.Run("negative", func(t *testing.T) {
filter := &LogFilter[*CIDR]{Negative: true}
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectFalse(t, filter.CheckKeep(notInCIDR, nil))
expect.False(t, filter.CheckKeep(inCIDR, nil))
expect.False(t, filter.CheckKeep(notInCIDR, nil))
filter.Values = cidr
ExpectFalse(t, filter.CheckKeep(inCIDR, nil))
ExpectTrue(t, filter.CheckKeep(notInCIDR, nil))
expect.False(t, filter.CheckKeep(inCIDR, nil))
expect.True(t, filter.CheckKeep(notInCIDR, nil))
})
}

View file

@ -2,42 +2,20 @@ package accesslog
import (
"bytes"
"encoding/json"
"iter"
"net"
"net/http"
"net/url"
"strconv"
"time"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog"
)
type (
CommonFormatter struct {
cfg *Fields
GetTimeNow func() time.Time // for testing purposes only
cfg *Fields
}
CombinedFormatter struct{ CommonFormatter }
JSONFormatter struct{ CommonFormatter }
JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
Method string `json:"method"`
Scheme string `json:"scheme"`
Host string `json:"host"`
URI string `json:"uri"`
Protocol string `json:"protocol"`
Status int `json:"status"`
Error string `json:"error,omitempty"`
ContentType string `json:"type"`
Size int64 `json:"size"`
Referer string `json:"referer"`
UserAgent string `json:"useragent"`
Query map[string][]string `json:"query,omitempty"`
Headers map[string][]string `json:"headers,omitempty"`
Cookies map[string]string `json:"cookies,omitempty"`
}
)
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
@ -49,12 +27,24 @@ func scheme(req *http.Request) string {
return "http"
}
func requestURI(u *url.URL, query url.Values) string {
uri := u.EscapedPath()
if len(query) > 0 {
uri += "?" + query.Encode()
func appendRequestURI(line []byte, req *http.Request, query iter.Seq2[string, []string]) []byte {
uri := req.URL.EscapedPath()
line = append(line, uri...)
isFirst := true
for k, v := range query {
if isFirst {
line = append(line, '?')
isFirst = false
} else {
line = append(line, '&')
}
line = append(line, k...)
line = append(line, '=')
for _, v := range v {
line = append(line, v...)
}
}
return uri
return line
}
func clientIP(req *http.Request) string {
@ -65,80 +55,102 @@ func clientIP(req *http.Request) string {
return req.RemoteAddr
}
// debug only.
func (f *CommonFormatter) SetGetTimeNow(getTimeNow func() time.Time) {
f.GetTimeNow = getTimeNow
// AppendLog appends one Common Log Format line
// (host ip - - [time] "METHOD uri PROTO" status size) to line and
// returns the extended slice. Uses append-style building (AppendFormat,
// AppendInt) to avoid intermediate string allocations per entry.
func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
	// Query iterator already filtered per the Fields config.
	query := f.cfg.Query.IterQuery(req.URL.Query())
	line = append(line, req.Host...)
	line = append(line, ' ')
	line = append(line, clientIP(req)...)
	line = append(line, " - - ["...)
	line = TimeNow().AppendFormat(line, LogTimeFormat)
	line = append(line, `] "`...)
	line = append(line, req.Method...)
	line = append(line, ' ')
	line = appendRequestURI(line, req, query)
	line = append(line, ' ')
	line = append(line, req.Proto...)
	line = append(line, '"')
	line = append(line, ' ')
	line = strconv.AppendInt(line, int64(res.StatusCode), 10)
	line = append(line, ' ')
	line = strconv.AppendInt(line, res.ContentLength, 10)
	return line
}
func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
line.WriteString(req.Host)
line.WriteRune(' ')
line.WriteString(clientIP(req))
line.WriteString(" - - [")
line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
line.WriteString("] \"")
line.WriteString(req.Method)
line.WriteRune(' ')
line.WriteString(requestURI(req.URL, query))
line.WriteRune(' ')
line.WriteString(req.Proto)
line.WriteString("\" ")
line.WriteString(strconv.Itoa(res.StatusCode))
line.WriteRune(' ')
line.WriteString(strconv.FormatInt(res.ContentLength, 10))
// AppendLog appends a Combined Log Format line: the Common format
// followed by the quoted Referer and User-Agent request headers.
func (f *CombinedFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
	line = f.CommonFormatter.AppendLog(line, req, res)
	line = append(line, " \""...)
	line = append(line, req.Referer()...)
	line = append(line, "\" \""...)
	line = append(line, req.UserAgent()...)
	line = append(line, '"')
	return line
}
func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
f.CommonFormatter.Format(line, req, res)
line.WriteString(" \"")
line.WriteString(req.Referer())
line.WriteString("\" \"")
line.WriteString(req.UserAgent())
line.WriteRune('"')
// zeroLogStringStringMapMarshaler adapts a map[string]string to
// zerolog.LogObjectMarshaler, emitting one string field per entry.
type zeroLogStringStringMapMarshaler struct {
	values map[string]string
}
func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
entry := JSONLogEntry{
Time: f.GetTimeNow().Format(LogTimeFormat),
IP: clientIP(req),
Method: req.Method,
Scheme: scheme(req),
Host: req.Host,
URI: requestURI(req.URL, query),
Protocol: req.Proto,
Status: res.StatusCode,
ContentType: res.Header.Get("Content-Type"),
Size: res.ContentLength,
Referer: req.Referer(),
UserAgent: req.UserAgent(),
Query: query,
Headers: headers,
Cookies: cookies,
// MarshalZerologObject emits each map entry as a string field.
// Ranging over a nil or empty map is a no-op in Go, so the previous
// explicit emptiness guard was redundant and has been removed.
func (z *zeroLogStringStringMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
	for k, v := range z.values {
		e.Str(k, v)
	}
}
// zeroLogStringStringSliceMapMarshaler adapts a map[string][]string to
// zerolog.LogObjectMarshaler, emitting one string-array field per entry.
type zeroLogStringStringSliceMapMarshaler struct {
	values map[string][]string
}
// MarshalZerologObject emits each map entry as a string-array field.
// Ranging over a nil or empty map is a no-op in Go, so the previous
// explicit emptiness guard was redundant and has been removed.
func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.Event) {
	for k, v := range z.values {
		e.Strs(k, v)
	}
}
func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
query := f.cfg.Query.ZerologQuery(req.URL.Query())
headers := f.cfg.Headers.ZerologHeaders(req.Header)
cookies := f.cfg.Cookies.ZerologCookies(req.Cookies())
contentType := res.Header.Get("Content-Type")
writer := bytes.NewBuffer(line)
logger := zerolog.New(writer).With().Logger()
event := logger.Info().
Str("time", TimeNow().Format(LogTimeFormat)).
Str("ip", clientIP(req)).
Str("method", req.Method).
Str("scheme", scheme(req)).
Str("host", req.Host).
Str("path", req.URL.Path).
Str("protocol", req.Proto).
Int("status", res.StatusCode).
Str("type", contentType).
Int64("size", res.ContentLength).
Str("referer", req.Referer()).
Str("useragent", req.UserAgent()).
Object("query", query).
Object("headers", headers).
Object("cookies", cookies)
if res.StatusCode >= 400 {
entry.Error = res.Status
if res.Status != "" {
event.Str("error", res.Status)
} else {
event.Str("error", http.StatusText(res.StatusCode))
}
}
if entry.ContentType == "" {
// try to get content type from request
entry.ContentType = req.Header.Get("Content-Type")
}
marshaller := json.NewEncoder(line)
err := marshaller.Encode(entry)
if err != nil {
logging.Err(err).Msg("failed to marshal json log")
}
// NOTE: zerolog will append a newline to the buffer
event.Send()
return writer.Bytes()
}

View file

@ -3,75 +3,47 @@ package accesslog
import (
"bytes"
"io"
"sync"
"github.com/spf13/afero"
)
// noLock is a no-op locker embedded in MockFile so it satisfies the
// Lock/Unlock part of the writer interface without real locking.
type noLock struct{}

func (noLock) Lock()   {}
func (noLock) Unlock() {}
type MockFile struct {
data []byte
position int64
sync.Mutex
afero.File
noLock
}
func (m *MockFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case io.SeekStart:
m.position = offset
case io.SeekCurrent:
m.position += offset
case io.SeekEnd:
m.position = int64(len(m.data)) + offset
func NewMockFile() *MockFile {
f, _ := afero.TempFile(afero.NewMemMapFs(), "", "")
return &MockFile{
File: f,
}
return m.position, nil
}
func (m *MockFile) Write(p []byte) (n int, err error) {
m.data = append(m.data, p...)
n = len(p)
m.position += int64(n)
return
}
func (m *MockFile) Name() string {
return "mock"
}
func (m *MockFile) Read(p []byte) (n int, err error) {
if m.position >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[m.position:])
m.position += int64(n)
return n, nil
}
func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) {
if off >= int64(len(m.data)) {
return 0, io.EOF
}
n = copy(p, m.data[off:])
return n, nil
}
func (m *MockFile) Close() error {
return nil
}
func (m *MockFile) Truncate(size int64) error {
m.data = m.data[:size]
m.position = size
return nil
}
func (m *MockFile) LineCount() int {
m.Lock()
defer m.Unlock()
return bytes.Count(m.data[:m.position], []byte("\n"))
}
func (m *MockFile) Len() int64 {
return m.position
filesize, _ := m.Seek(0, io.SeekEnd)
_, _ = m.Seek(0, io.SeekStart)
return filesize
}
func (m *MockFile) Content() []byte {
return m.data[:m.position]
buf := bytes.NewBuffer(nil)
m.Seek(0, io.SeekStart)
_, _ = buf.ReadFrom(m.File)
m.Seek(0, io.SeekStart)
return buf.Bytes()
}
// NumLines reports how many log lines the file holds. A trailing chunk
// without a final newline still counts as one line.
func (m *MockFile) NumLines() int {
	data := m.Content()
	if len(data) == 0 {
		return 0
	}
	n := bytes.Count(data, []byte{'\n'})
	if data[len(data)-1] != '\n' {
		n++ // unterminated last line
	}
	return n
}

View file

@ -0,0 +1,46 @@
package accesslog
import "strings"
// MultiWriter fans out writes (and Lock/Unlock calls) to several
// AccessLogIO destinations at once.
type MultiWriter struct {
	writers []AccessLogIO
}
// NewMultiWriter combines the given writers into one AccessLogIO.
// It returns nil for no writers and the writer itself for exactly one,
// so the fan-out wrapper is only allocated when actually needed.
func NewMultiWriter(writers ...AccessLogIO) AccessLogIO {
	switch len(writers) {
	case 0:
		return nil
	case 1:
		return writers[0]
	default:
		return &MultiWriter{writers: writers}
	}
}
// Write writes p to every underlying writer.
//
// It always attempts all writers and reports len(p) bytes written, but
// now returns the first write error encountered — the previous version
// silently discarded every writer's error, hiding failing destinations.
func (w *MultiWriter) Write(p []byte) (n int, err error) {
	for _, writer := range w.writers {
		if _, werr := writer.Write(p); werr != nil && err == nil {
			err = werr
		}
	}
	return len(p), err
}
// Lock acquires every underlying writer's lock, in slice order.
// NOTE(review): callers must always lock through the MultiWriter so the
// acquisition order stays consistent; mixed direct/indirect locking
// could deadlock — confirm against call sites.
func (w *MultiWriter) Lock() {
	for _, writer := range w.writers {
		writer.Lock()
	}
}
// Unlock releases every underlying writer's lock, in the same slice
// order as Lock (release order does not matter for independent locks).
func (w *MultiWriter) Unlock() {
	for _, writer := range w.writers {
		writer.Unlock()
	}
}
// Name returns the comma-separated names of all underlying writers.
func (w *MultiWriter) Name() string {
	var sb strings.Builder
	for i, writer := range w.writers {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(writer.Name())
	}
	return sb.String()
}

View file

@ -1,6 +1,7 @@
package accesslog
import (
"fmt"
"strconv"
"github.com/yusing/go-proxy/internal/gperr"
@ -8,8 +9,9 @@ import (
)
type Retention struct {
Days uint64 `json:"days"`
Last uint64 `json:"last"`
Days uint64 `json:"days"`
Last uint64 `json:"last"`
KeepSize uint64 `json:"keep_size"`
}
var (
@ -17,7 +19,8 @@ var (
ErrZeroValue = gperr.New("zero value")
)
var defaultChunkSize = 64 * 1024 // 64KB
// see back_scanner_test.go#L210 for benchmarks
var defaultChunkSize = 256 * kilobyte
// Syntax:
//
@ -25,6 +28,8 @@ var defaultChunkSize = 64 * 1024 // 64KB
//
// last <N>
//
// <N> KB|MB|GB|kb|mb|gb
//
// Parse implements strutils.Parser.
func (r *Retention) Parse(v string) (err error) {
split := strutils.SplitSpace(v)
@ -35,22 +40,55 @@ func (r *Retention) Parse(v string) (err error) {
case "last":
r.Last, err = strconv.ParseUint(split[1], 10, 64)
default: // <N> days|weeks|months
r.Days, err = strconv.ParseUint(split[0], 10, 64)
n, err := strconv.ParseUint(split[0], 10, 64)
if err != nil {
return
return err
}
switch split[1] {
case "days":
case "weeks":
r.Days *= 7
case "months":
r.Days *= 30
case "day", "days":
r.Days = n
case "week", "weeks":
r.Days = n * 7
case "month", "months":
r.Days = n * 30
case "kb", "Kb":
r.KeepSize = n * kilobits
case "KB":
r.KeepSize = n * kilobyte
case "mb", "Mb":
r.KeepSize = n * megabits
case "MB":
r.KeepSize = n * megabyte
case "gb", "Gb":
r.KeepSize = n * gigabits
case "GB":
r.KeepSize = n * gigabyte
default:
return ErrInvalidSyntax.Subject("unit " + split[1])
}
}
if r.Days == 0 && r.Last == 0 {
if !r.IsValid() {
return ErrZeroValue
}
return
}
// String renders the retention policy in the same form Parse accepts,
// preferring days, then last-N, then size; "<invalid>" when unset.
func (r *Retention) String() string {
	switch {
	case r.Days > 0:
		return fmt.Sprintf("%d days", r.Days)
	case r.Last > 0:
		return fmt.Sprintf("last %d", r.Last)
	case r.KeepSize > 0:
		return strutils.FormatByteSize(r.KeepSize)
	default:
		return "<invalid>"
	}
}
// IsValid reports whether the (possibly nil) retention policy has at
// least one of its limits set.
func (r *Retention) IsValid() bool {
	return r != nil && (r.Days > 0 || r.Last > 0 || r.KeepSize > 0)
}

View file

@ -4,7 +4,7 @@ import (
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseRetention(t *testing.T) {
@ -24,9 +24,9 @@ func TestParseRetention(t *testing.T) {
r := &Retention{}
err := r.Parse(test.input)
if !test.shouldErr {
ExpectNoError(t, err)
expect.NoError(t, err)
} else {
ExpectEqual(t, r, test.expected)
expect.Equal(t, r, test.expected)
}
})
}

View file

@ -4,116 +4,252 @@ import (
"bytes"
"io"
"time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/utils/synk"
)
func (l *AccessLogger) rotate() (err error) {
// Get retention configuration
config := l.Config().Retention
var shouldKeep func(t time.Time, lineCount int) bool
// supportRotate is the subset of file operations rotation needs:
// sequential and random-access reads/writes plus truncation
// (satisfied by e.g. *os.File).
type supportRotate interface {
	io.Reader
	io.Writer
	io.Seeker
	io.ReaderAt
	io.WriterAt
	Truncate(size int64) error
}
// RotateResult summarizes one rotation pass over an access-log file.
type RotateResult struct {
	Filename        string
	OriginalSize    int64 // original size of the file
	NumBytesRead    int64 // number of bytes read from the file
	NumBytesKeep    int64 // number of bytes to keep
	NumLinesRead    int   // number of lines read from the file
	NumLinesKeep    int   // number of lines to keep
	NumLinesInvalid int   // number of invalid lines (no parseable timestamp)
}
// Print logs the rotation statistics at info level on the given logger,
// with byte counts formatted human-readably.
func (r *RotateResult) Print(logger *zerolog.Logger) {
	logger.Info().
		Str("original_size", strutils.FormatByteSize(r.OriginalSize)).
		Str("bytes_read", strutils.FormatByteSize(r.NumBytesRead)).
		Str("bytes_keep", strutils.FormatByteSize(r.NumBytesKeep)).
		Int("lines_read", r.NumLinesRead).
		Int("lines_keep", r.NumLinesKeep).
		Int("lines_invalid", r.NumLinesInvalid).
		Msg("log rotate result")
}
// lineInfo records where a kept log line lives in the original file so
// it can be copied during rotation.
type lineInfo struct {
	Pos  int64 // Position from the start of the file
	Size int64 // Size of this line
}

// Pool for rotation scratch buffers. Created with zero initial size so
// idle loggers hold no memory; second argument is presumably the max
// pooled buffer size (16 MiB) — confirm against synk.NewBytesPool.
var rotateBytePool = synk.NewBytesPool(0, 16*1024*1024)
// rotateLogFile rotates the log file based on the retention policy.
// It returns the result of the rotation and an error if any.
//
// The file is rotated by reading the file backward line-by-line
// and stop once error occurs or found a line that should not be kept.
//
// Any invalid lines will be skipped and not included in the result.
//
// If the file does not need to be rotated, it returns nil, nil.
func rotateLogFile(file supportRotate, config *Retention) (result *RotateResult, err error) {
if config.KeepSize > 0 {
return rotateLogFileBySize(file, config)
}
var shouldStop func() bool
t := TimeNow()
if config.Last > 0 {
shouldKeep = func(_ time.Time, lineCount int) bool {
return lineCount < int(config.Last)
}
shouldStop = func() bool { return result.NumLinesKeep-result.NumLinesInvalid == int(config.Last) }
// not needed to parse time for last N lines
} else if config.Days > 0 {
cutoff := time.Now().AddDate(0, 0, -int(config.Days))
shouldKeep = func(t time.Time, _ int) bool {
return !t.IsZero() && !t.Before(cutoff)
}
cutoff := TimeNow().AddDate(0, 0, -int(config.Days)+1)
shouldStop = func() bool { return t.Before(cutoff) }
} else {
return nil // No retention policy set
return nil, nil // should not happen
}
s := NewBackScanner(l.io, defaultChunkSize)
nRead := 0
nLines := 0
s := NewBackScanner(file, defaultChunkSize)
result = &RotateResult{
OriginalSize: s.FileSize(),
}
// nothing to rotate, return the nothing
if result.OriginalSize == 0 {
return nil, nil
}
// Store the line positions and sizes we want to keep
linesToKeep := make([]lineInfo, 0)
lastLineValid := false
for s.Scan() {
nRead += len(s.Bytes()) + 1
nLines++
t := ParseLogTime(s.Bytes())
if !shouldKeep(t, nLines) {
result.NumLinesRead++
lineSize := int64(len(s.Bytes()) + 1) // +1 for newline
linePos := result.OriginalSize - result.NumBytesRead - lineSize
result.NumBytesRead += lineSize
// Check if line has valid time
t = ParseLogTime(s.Bytes())
if t.IsZero() {
result.NumLinesInvalid++
lastLineValid = false
continue
}
// Check if we should stop based on retention policy
if shouldStop() {
break
}
// Add line to those we want to keep
if lastLineValid {
last := linesToKeep[len(linesToKeep)-1]
linesToKeep[len(linesToKeep)-1] = lineInfo{
Pos: last.Pos - lineSize,
Size: last.Size + lineSize,
}
} else {
linesToKeep = append(linesToKeep, lineInfo{
Pos: linePos,
Size: lineSize,
})
}
result.NumBytesKeep += lineSize
result.NumLinesKeep++
lastLineValid = true
}
if s.Err() != nil {
return s.Err()
return nil, s.Err()
}
beg := int64(nRead)
if _, err := l.io.Seek(-beg, io.SeekEnd); err != nil {
return err
}
buf := make([]byte, nRead)
if _, err := l.io.Read(buf); err != nil {
return err
// nothing to keep, truncate to empty
if len(linesToKeep) == 0 {
return nil, file.Truncate(0)
}
if err := l.writeTruncate(buf); err != nil {
return err
// nothing to rotate, return nothing
if result.NumBytesKeep == result.OriginalSize {
return nil, nil
}
return nil
// Read each line and write it to the beginning of the file
writePos := int64(0)
buf := rotateBytePool.Get()
defer rotateBytePool.Put(buf)
// in reverse order to keep the order of the lines (from old to new)
for i := len(linesToKeep) - 1; i >= 0; i-- {
line := linesToKeep[i]
n := line.Size
if cap(buf) < int(n) {
buf = make([]byte, n)
}
buf = buf[:n]
// Read the line from its original position
if _, err := file.ReadAt(buf, line.Pos); err != nil {
return nil, err
}
// Write it to the new position
if _, err := file.WriteAt(buf, writePos); err != nil {
return nil, err
}
writePos += n
}
if err := file.Truncate(writePos); err != nil {
return nil, err
}
return result, nil
}
func (l *AccessLogger) writeTruncate(buf []byte) (err error) {
// Seek to beginning and truncate
if _, err := l.io.Seek(0, 0); err != nil {
return err
}
// Write buffer back to file
nWritten, err := l.buffered.Write(buf)
// rotateLogFileBySize rotates the log file by size.
// It returns the result of the rotation and an error if any.
//
// The file is not being read, it just truncate the file to the new size.
//
// Invalid lines will not be detected and included in the result.
func rotateLogFileBySize(file supportRotate, config *Retention) (result *RotateResult, err error) {
filesize, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if err = l.buffered.Flush(); err != nil {
return err
return nil, err
}
// Truncate file
if err = l.io.Truncate(int64(nWritten)); err != nil {
return err
result = &RotateResult{
OriginalSize: filesize,
}
// check bytes written == buffer size
if nWritten != len(buf) {
return io.ErrShortWrite
keepSize := int64(config.KeepSize)
if keepSize >= filesize {
result.NumBytesKeep = filesize
return result, nil
}
return
result.NumBytesKeep = keepSize
err = file.Truncate(keepSize)
if err != nil {
return nil, err
}
return result, nil
}
const timeLen = len(`"time":"`)
var timeJSON = []byte(`"time":"`)
// ParseLogTime parses the time from the log line.
// It returns the time if the time is found and valid in the log line,
// otherwise it returns zero time.
func ParseLogTime(line []byte) (t time.Time) {
if len(line) == 0 {
return
}
if i := bytes.Index(line, timeJSON); i != -1 { // JSON format
var jsonStart = i + timeLen
var jsonEnd = i + timeLen + len(LogTimeFormat)
if len(line) < jsonEnd {
return
}
timeStr := line[jsonStart:jsonEnd]
t, _ = time.Parse(LogTimeFormat, string(timeStr))
if timeStr := ExtractTime(line); timeStr != nil {
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
return
}
// Common/Combined format
// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
start := bytes.IndexByte(line, '[')
if start == -1 {
return
}
end := bytes.IndexByte(line[start:], ']')
if end == -1 {
return
}
end += start // adjust end position relative to full line
timeStr := line[start+1 : end]
t, _ = time.Parse(LogTimeFormat, string(timeStr)) // ignore error
return
}
// timeJSON is the JSON key prefix that precedes the timestamp value.
var timeJSON = []byte(`"time":"`)

// ExtractTime extracts the timestamp portion from a single log line.
//
// Two formats are recognized:
//   - JSON lines (first byte '{'): the value of the "time" field.
//   - Common/Combined access-log lines: the text after the first '['.
//
// It returns nil when the line is empty or no timestamp is found.
// The returned slice aliases line and is NOT validated; callers
// (e.g. ParseLogTime) are expected to parse it with LogTimeFormat.
func ExtractTime(line []byte) []byte {
	// guard: indexing line[0] below would panic on an empty line
	// (e.g. the trailing element after splitting on '\n')
	if len(line) == 0 {
		return nil
	}
	switch line[0] {
	case '{': // JSON format
		i := bytes.Index(line, timeJSON)
		if i == -1 {
			return nil // invalid JSON line
		}
		start := i + len(timeJSON)
		end := start + len(LogTimeFormat)
		if len(line) < end {
			return nil // line truncated before the timestamp ends
		}
		return line[start:end]
	default:
		// Common/Combined format
		// Format: <virtual host> <host ip> - - [02/Jan/2006:15:04:05 -0700] ...
		start := bytes.IndexByte(line, '[')
		if start == -1 {
			return nil
		}
		end := start + 1 + len(LogTimeFormat)
		if len(line) < end {
			return nil // line truncated before the timestamp ends
		}
		return line[start+1 : end]
	}
}

View file

@ -1,6 +1,7 @@
package accesslog_test
import (
"bytes"
"fmt"
"testing"
"time"
@ -8,79 +9,280 @@ import (
. "github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
var (
	// testTime is a fixed reference instant shared by the rotation tests.
	testTime = expect.Must(time.Parse(time.RFC3339, "2024-01-31T03:04:05Z"))
	// testTimeStr is testTime rendered in the access-log time format.
	testTimeStr = testTime.Format(LogTimeFormat)
)
func TestParseLogTime(t *testing.T) {
tests := []string{
`{"foo":"bar","time":"%s","bar":"baz"}`,
`example.com 192.168.1.1 - - [%s] "GET / HTTP/1.1" 200 1234`,
}
testTime := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
testTimeStr := testTime.Format(LogTimeFormat)
t.Run("valid time", func(t *testing.T) {
tests := []string{
`{"foo":"bar","time":"%s","bar":"baz"}`,
`example.com 192.168.1.1 - - [%s] "GET / HTTP/1.1" 200 1234`,
}
for i, test := range tests {
tests[i] = fmt.Sprintf(test, testTimeStr)
}
for i, test := range tests {
tests[i] = fmt.Sprintf(test, testTimeStr)
}
for _, test := range tests {
t.Run(test, func(t *testing.T) {
actual := ParseLogTime([]byte(test))
ExpectTrue(t, actual.Equal(testTime))
for _, test := range tests {
t.Run(test, func(t *testing.T) {
extracted := ExtractTime([]byte(test))
expect.Equal(t, string(extracted), testTimeStr)
got := ParseLogTime([]byte(test))
expect.True(t, got.Equal(testTime), "expected %s, got %s", testTime, got)
})
}
})
t.Run("invalid time", func(t *testing.T) {
tests := []string{
`{"foo":"bar","time":"invalid","bar":"baz"}`,
`example.com 192.168.1.1 - - [invalid] "GET / HTTP/1.1" 200 1234`,
}
for _, test := range tests {
t.Run(test, func(t *testing.T) {
expect.True(t, ParseLogTime([]byte(test)).IsZero(), "expected zero time, got %s", ParseLogTime([]byte(test)))
})
}
})
}
// TestRotateKeepLast exercises the "last N" and "N days" retention policies
// for every available log format against an in-memory mock file.
func TestRotateKeepLast(t *testing.T) {
	for _, format := range AvailableFormats {
		t.Run(string(format)+" keep last", func(t *testing.T) {
			file := NewMockFile()
			MockTimeNow(testTime)
			logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
				Format: format,
			})
			expect.Nil(t, logger.Config().Retention)
			// write 10 lines, then rotate keeping only the last 5
			for range 10 {
				logger.Log(req, resp)
			}
			expect.NoError(t, logger.Flush())
			expect.Greater(t, file.Len(), int64(0))
			expect.Equal(t, file.NumLines(), 10)
			retention := strutils.MustParse[*Retention]("last 5")
			expect.Equal(t, retention.Days, 0)
			expect.Equal(t, retention.Last, 5)
			expect.Equal(t, retention.KeepSize, 0)
			logger.Config().Retention = retention
			result, err := logger.Rotate()
			expect.NoError(t, err)
			expect.Equal(t, file.NumLines(), int(retention.Last))
			expect.Equal(t, result.NumLinesKeep, int(retention.Last))
			expect.Equal(t, result.NumLinesInvalid, 0)
		})
		t.Run(string(format)+" keep days", func(t *testing.T) {
			file := NewMockFile()
			logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
				Format: format,
			})
			expect.Nil(t, logger.Config().Retention)
			// one line per day, oldest first, newest dated testTime
			nLines := 10
			for i := range nLines {
				MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1))
				logger.Log(req, resp)
			}
			logger.Flush()
			expect.Equal(t, file.NumLines(), nLines)
			retention := strutils.MustParse[*Retention]("3 days")
			expect.Equal(t, retention.Days, 3)
			expect.Equal(t, retention.Last, 0)
			expect.Equal(t, retention.KeepSize, 0)
			logger.Config().Retention = retention
			MockTimeNow(testTime)
			result, err := logger.Rotate()
			expect.NoError(t, err)
			expect.Equal(t, file.NumLines(), int(retention.Days))
			expect.Equal(t, result.NumLinesKeep, int(retention.Days))
			expect.Equal(t, result.NumLinesInvalid, 0)
			// verify the kept lines are exactly the newest `Days` entries,
			// preserved in oldest-to-newest order
			rotated := file.Content()
			rotatedLines := bytes.Split(rotated, []byte("\n"))
			for i, line := range rotatedLines {
				if i >= int(retention.Days) { // may ends with a newline
					break
				}
				timeBytes := ExtractTime(line)
				got, err := time.Parse(LogTimeFormat, string(timeBytes))
				expect.NoError(t, err)
				want := testTime.AddDate(0, 0, -int(retention.Days)+i+1)
				expect.True(t, got.Equal(want), "expected %s, got %s", want, got)
			}
		})
	}
}
func TestRetentionCommonFormat(t *testing.T) {
var file MockFile
logger := NewAccessLogger(task.RootTask("test", false), &file, &Config{
Format: FormatCommon,
BufferSize: 1024,
})
for range 10 {
logger.Log(req, resp)
}
logger.Flush()
// test.Finish(nil)
ExpectEqual(t, logger.Config().Retention, nil)
ExpectTrue(t, file.Len() > 0)
ExpectEqual(t, file.LineCount(), 10)
t.Run("keep last", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("last 5")
ExpectEqual(t, logger.Config().Retention.Days, 0)
ExpectEqual(t, logger.Config().Retention.Last, 5)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.LineCount(), 5)
})
_ = file.Truncate(0)
timeNow := time.Now()
for i := range 10 {
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
return timeNow.AddDate(0, 0, -10+i)
}
logger.Log(req, resp)
}
logger.Flush()
ExpectEqual(t, file.LineCount(), 10)
t.Run("keep days", func(t *testing.T) {
logger.Config().Retention = strutils.MustParse[*Retention]("3 days")
ExpectEqual(t, logger.Config().Retention.Days, 3)
ExpectEqual(t, logger.Config().Retention.Last, 0)
ExpectNoError(t, logger.Rotate())
ExpectEqual(t, file.LineCount(), 3)
rotated := string(file.Content())
_ = file.Truncate(0)
for i := range 3 {
logger.Formatter.(*CommonFormatter).GetTimeNow = func() time.Time {
return timeNow.AddDate(0, 0, -3+i)
func TestRotateKeepFileSize(t *testing.T) {
for _, format := range AvailableFormats {
t.Run(string(format)+" keep size no rotation", func(t *testing.T) {
file := NewMockFile()
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
Format: format,
})
expect.Nil(t, logger.Config().Retention)
nLines := 10
for i := range nLines {
MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1))
logger.Log(req, resp)
}
logger.Flush()
expect.Equal(t, file.NumLines(), nLines)
retention := strutils.MustParse[*Retention]("100 KB")
expect.Equal(t, retention.KeepSize, 100*1024)
expect.Equal(t, retention.Days, 0)
expect.Equal(t, retention.Last, 0)
logger.Config().Retention = retention
MockTimeNow(testTime)
result, err := logger.Rotate()
expect.NoError(t, err)
// file should be untouched as 100KB > 10 lines * bytes per line
expect.Equal(t, result.NumBytesKeep, file.Len())
expect.Equal(t, result.NumBytesRead, 0, "should not read any bytes")
})
}
t.Run("keep size with rotation", func(t *testing.T) {
file := NewMockFile()
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
Format: FormatJSON,
})
expect.Nil(t, logger.Config().Retention)
nLines := 100
for i := range nLines {
MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1))
logger.Log(req, resp)
}
ExpectEqual(t, rotated, string(file.Content()))
logger.Flush()
expect.Equal(t, file.NumLines(), nLines)
retention := strutils.MustParse[*Retention]("10 KB")
expect.Equal(t, retention.KeepSize, 10*1024)
expect.Equal(t, retention.Days, 0)
expect.Equal(t, retention.Last, 0)
logger.Config().Retention = retention
MockTimeNow(testTime)
result, err := logger.Rotate()
expect.NoError(t, err)
expect.Equal(t, result.NumBytesKeep, int64(retention.KeepSize))
expect.Equal(t, file.Len(), int64(retention.KeepSize))
expect.Equal(t, result.NumBytesRead, 0, "should not read any bytes")
})
}
// skipping invalid lines is not supported for keep file_size.
//
// TestRotateSkipInvalidTime checks that lines without a parsable timestamp
// are counted as invalid and dropped during day-based rotation.
func TestRotateSkipInvalidTime(t *testing.T) {
	for _, format := range AvailableFormats {
		t.Run(string(format), func(t *testing.T) {
			file := NewMockFile()
			logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
				Format: format,
			})
			expect.Nil(t, logger.Config().Retention)
			// interleave a bogus line after every valid entry
			nLines := 10
			for i := range nLines {
				MockTimeNow(testTime.AddDate(0, 0, -nLines+i+1))
				logger.Log(req, resp)
				logger.Flush()
				n, err := file.Write([]byte("invalid time\n"))
				expect.NoError(t, err)
				expect.Equal(t, n, len("invalid time\n"))
			}
			expect.Equal(t, file.NumLines(), 2*nLines)
			retention := strutils.MustParse[*Retention]("3 days")
			expect.Equal(t, retention.Days, 3)
			expect.Equal(t, retention.Last, 0)
			logger.Config().Retention = retention
			result, err := logger.Rotate()
			expect.NoError(t, err)
			// should read one invalid line after every valid line
			expect.Equal(t, result.NumLinesKeep, int(retention.Days))
			expect.Equal(t, result.NumLinesInvalid, nLines-int(retention.Days)*2)
			expect.Equal(t, file.NumLines(), int(retention.Days))
		})
	}
}
// BenchmarkRotate measures Rotate under each retention policy on a
// 100-line JSON log, restoring the file content before every iteration.
func BenchmarkRotate(b *testing.B) {
	tests := []*Retention{
		{Days: 30},
		{Last: 100},
		{KeepSize: 24 * 1024},
	}
	for _, retention := range tests {
		b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
			file := NewMockFile()
			logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
				Format:    FormatJSON,
				Retention: retention,
			})
			for i := range 100 {
				MockTimeNow(testTime.AddDate(0, 0, -100+i+1))
				logger.Log(req, resp)
			}
			logger.Flush()
			content := file.Content()
			b.ResetTimer()
			for b.Loop() {
				// exclude the per-iteration file reset from the timed region
				b.StopTimer()
				// NOTE(review): reassigning the local `file` does not swap the
				// writer captured by `logger` at construction — Rotate likely
				// still operates on the original (already rotated) mock.
				// Verify the NewAccessLoggerWithIO/Rotate wiring.
				file = NewMockFile()
				_, _ = file.Write(content)
				b.StartTimer()
				_, _ = logger.Rotate()
			}
		})
	}
}
// BenchmarkRotateWithInvalidTime measures Rotate on a 10000-line JSON log
// where every 10th entry is followed by an unparsable line, exercising the
// invalid-line skipping path of each retention policy.
func BenchmarkRotateWithInvalidTime(b *testing.B) {
	tests := []*Retention{
		{Days: 30},
		{Last: 100},
		{KeepSize: 24 * 1024},
	}
	for _, retention := range tests {
		b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
			file := NewMockFile()
			logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
				Format:    FormatJSON,
				Retention: retention,
			})
			for i := range 10000 {
				MockTimeNow(testTime.AddDate(0, 0, -10000+i+1))
				logger.Log(req, resp)
				if i%10 == 0 {
					_, _ = file.Write([]byte("invalid time\n"))
				}
			}
			logger.Flush()
			content := file.Content()
			b.ResetTimer()
			for b.Loop() {
				// exclude the per-iteration file reset from the timed region
				b.StopTimer()
				// NOTE(review): reassigning the local `file` does not swap the
				// writer captured by `logger` at construction — Rotate likely
				// still operates on the original (already rotated) mock.
				// Verify the NewAccessLoggerWithIO/Rotate wiring.
				file = NewMockFile()
				_, _ = file.Write(content)
				b.StartTimer()
				_, _ = logger.Rotate()
			}
		})
	}
}

View file

@ -0,0 +1,18 @@
package accesslog
import (
"io"
"os"
)
// StdoutLogger is a log output destination that writes to standard output.
// Locking is a no-op: writes go straight to the embedded writer.
type StdoutLogger struct {
	io.Writer
}

// stdoutIO is the shared stdout destination instance.
var stdoutIO = &StdoutLogger{os.Stdout}

// Lock is a no-op; stdout needs no file-level locking here.
func (l *StdoutLogger) Lock() {}

// Unlock is the no-op counterpart to Lock.
func (l *StdoutLogger) Unlock() {}

// Name returns the identifier used for this log destination.
func (l *StdoutLogger) Name() string {
	return "stdout"
}

View file

@ -0,0 +1,48 @@
package accesslog
import (
"time"
"github.com/yusing/go-proxy/internal/task"
"go.uber.org/atomic"
)
var (
	// TimeNow returns the current (possibly cached) time. It defaults to
	// DefaultTimeNow and may be swapped out in tests via MockTimeNow.
	TimeNow = DefaultTimeNow

	// shouldCallTimeNow is set by the background ticker to signal that the
	// cached value is stale and time.Now should be called again.
	shouldCallTimeNow atomic.Bool

	timeNowTicker = time.NewTicker(shouldCallTimeNowInterval)

	// lastTimeNow caches the last time.Now result. Stored atomically so
	// concurrent DefaultTimeNow callers do not race on the cached value
	// (plain time.Time reads/writes are not atomic).
	lastTimeNow = atomic.NewTime(time.Now())
)

// shouldCallTimeNowInterval bounds the staleness of the cached time.
const shouldCallTimeNowInterval = 100 * time.Millisecond

// MockTimeNow replaces TimeNow with a function returning the fixed time t.
// Intended for tests; not safe to call concurrently with readers of TimeNow.
func MockTimeNow(t time.Time) {
	TimeNow = func() time.Time {
		return t
	}
}

// DefaultTimeNow is a time.Now wrapper that reduces the number of calls to time.Now
// by caching the result and only allow calling time.Now when the ticker fires.
//
// Returned value may have +-100ms error.
func DefaultTimeNow() time.Time {
	// CompareAndSwap guarantees only one caller refreshes the cache per tick;
	// everyone else reads the atomically-stored cached value.
	if shouldCallTimeNow.CompareAndSwap(true, false) {
		now := time.Now()
		lastTimeNow.Store(now)
		return now
	}
	return lastTimeNow.Load()
}
func init() {
go func() {
for {
select {
case <-task.RootContext().Done():
return
case <-timeNowTicker.C:
shouldCallTimeNow.Store(true)
}
}
}()
}

View file

@ -0,0 +1,102 @@
package accesslog
import (
"testing"
"time"
)
// BenchmarkTimeNow compares calling time.Now directly ("default") against
// the cached DefaultTimeNow wrapper ("reduced_call").
func BenchmarkTimeNow(b *testing.B) {
	b.Run("default", func(b *testing.B) {
		for b.Loop() {
			time.Now()
		}
	})
	b.Run("reduced_call", func(b *testing.B) {
		for b.Loop() {
			DefaultTimeNow()
		}
	})
}
// TestDefaultTimeNow verifies the caching contract: the cached value is
// returned unchanged until shouldCallTimeNow is set, after which exactly
// one refresh occurs. Order-dependent: relies on package-level state.
func TestDefaultTimeNow(t *testing.T) {
	// Get initial time
	t1 := DefaultTimeNow()
	// Second call should return the same time without calling time.Now
	t2 := DefaultTimeNow()
	if !t1.Equal(t2) {
		t.Errorf("Expected t1 == t2, got t1 = %v, t2 = %v", t1, t2)
	}
	// Set shouldCallTimeNow to true
	shouldCallTimeNow.Store(true)
	// This should update the lastTimeNow
	t3 := DefaultTimeNow()
	// The time should have changed
	if t2.Equal(t3) {
		t.Errorf("Expected t2 != t3, got t2 = %v, t3 = %v", t2, t3)
	}
	// Fourth call should return the same time as third call
	t4 := DefaultTimeNow()
	if !t3.Equal(t4) {
		t.Errorf("Expected t3 == t4, got t3 = %v, t4 = %v", t3, t4)
	}
}
// TestMockTimeNow verifies that MockTimeNow pins TimeNow to a fixed instant.
func TestMockTimeNow(t *testing.T) {
	// Swap in a mock and guarantee the real implementation is restored.
	saved := TimeNow
	defer func() { TimeNow = saved }()

	want := time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)
	MockTimeNow(want)

	// Every call must now yield exactly the mocked instant.
	if got := TimeNow(); !got.Equal(want) {
		t.Errorf("Expected %v, got %v", want, got)
	}
}
// TestTimeNowTicker verifies the background ticker sets shouldCallTimeNow
// within one interval, and that DefaultTimeNow clears it again.
// Timing-sensitive: sleeps slightly longer than the tick interval.
func TestTimeNowTicker(t *testing.T) {
	// This test verifies that the ticker properly updates shouldCallTimeNow
	// Reset the flag
	shouldCallTimeNow.Store(false)
	// Wait for the ticker to tick (slightly more than the interval)
	time.Sleep(shouldCallTimeNowInterval + 10*time.Millisecond)
	// The ticker should have set shouldCallTimeNow to true
	if !shouldCallTimeNow.Load() {
		t.Error("Expected shouldCallTimeNow to be true after ticker interval")
	}
	// Call DefaultTimeNow which should reset the flag
	DefaultTimeNow()
	// Check that the flag is reset
	if shouldCallTimeNow.Load() {
		t.Error("Expected shouldCallTimeNow to be false after calling DefaultTimeNow")
	}
}
/*
BenchmarkTimeNow
BenchmarkTimeNow/default
BenchmarkTimeNow/default-20 48158628 24.86 ns/op 0 B/op 0 allocs/op
BenchmarkTimeNow/reduced_call
BenchmarkTimeNow/reduced_call-20 1000000000 1.000 ns/op 0 B/op 0 allocs/op
*/

View file

@ -0,0 +1,11 @@
package accesslog
// Binary (1024-based) byte multipliers and decimal (1000-based) bit
// multipliers.
const (
	kilobyte = 1024
	megabyte = 1024 * kilobyte
	gigabyte = 1024 * megabyte

	kilobits = 1000
	megabits = 1000 * kilobits
	gigabits = 1000 * megabits
)

View file

@ -84,7 +84,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
if s.UseAccessLog() {
var err error
s.accessLogger, err = accesslog.NewFileAccessLogger(s.task, s.AccessLog)
s.accessLogger, err = accesslog.NewAccessLogger(s.task, s.AccessLog)
if err != nil {
s.task.Finish(err)
return gperr.Wrap(err)

View file

@ -111,7 +111,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
if r.UseAccessLog() {
var err error
r.rp.AccessLogger, err = accesslog.NewFileAccessLogger(r.task, r.AccessLog)
r.rp.AccessLogger, err = accesslog.NewAccessLogger(r.task, r.AccessLog)
if err != nil {
r.task.Finish(err)
return gperr.Wrap(err)

View file

@ -0,0 +1,42 @@
package synk
import "sync"
type (
	// Pool is a wrapper of sync.Pool that limits the size of the object.
	Pool[T any] struct {
		pool    sync.Pool
		maxSize int
	}
	// BytesPool is a Pool specialized for byte slices.
	BytesPool = Pool[byte]
)

// Default sizing hints for byte pools.
const (
	DefaultInitBytes = 1024
	DefaultMaxBytes  = 1024 * 1024
)

// NewPool returns a Pool whose freshly created slices have capacity
// initSize; slices whose capacity has grown beyond maxSize are discarded
// when returned via Put.
func NewPool[T any](initSize int, maxSize int) *Pool[T] {
	newSlice := func() any {
		return make([]T, 0, initSize)
	}
	return &Pool[T]{
		pool:    sync.Pool{New: newSlice},
		maxSize: maxSize,
	}
}

// NewBytesPool returns a byte-slice pool with the given sizing limits.
func NewBytesPool(initSize int, maxSize int) *BytesPool {
	return NewPool[byte](initSize, maxSize)
}

// Get returns a zero-length slice from the pool; its capacity depends on
// the slice's history and is at least initSize for fresh allocations.
func (p *Pool[T]) Get() []T {
	return p.pool.Get().([]T)
}

// Put returns a slice to the pool with its length reset to zero.
// Oversized slices (cap > maxSize) are dropped so the pool cannot pin
// arbitrarily large allocations.
func (p *Pool[T]) Put(b []T) {
	if cap(b) > p.maxSize {
		return
	}
	p.pool.Put(b[:0])
}