mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-09 07:54:03 +02:00
feat(acl): connection level ip/geo blocking
- fixed access log logic - implement acl at connection level - acl logging - ip/cidr blocking - geoblocking with MaxMind database
This commit is contained in:
parent
e513db62b0
commit
b427ff1f88
32 changed files with 1359 additions and 193 deletions
|
@ -40,5 +40,5 @@ func StartAgentServer(parent task.Parent, opt Options) {
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
server.Start(parent, agentServer, logger)
|
server.Start(parent, agentServer, nil, logger)
|
||||||
}
|
}
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -22,7 +22,7 @@ require (
|
||||||
golang.org/x/oauth2 v0.29.0 // oauth2 authentication
|
golang.org/x/oauth2 v0.29.0 // oauth2 authentication
|
||||||
golang.org/x/text v0.24.0 // string utilities
|
golang.org/x/text v0.24.0 // string utilities
|
||||||
golang.org/x/time v0.11.0 // time utilities
|
golang.org/x/time v0.11.0 // time utilities
|
||||||
gopkg.in/yaml.v3 v3.0.1 // yaml parsing for different config files
|
gopkg.in/yaml.v3 v3.0.1 // indirect; yaml parsing for different config files
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2
|
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2
|
||||||
|
@ -30,8 +30,10 @@ replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2
|
||||||
require (
|
require (
|
||||||
github.com/bytedance/sonic v1.13.2
|
github.com/bytedance/sonic v1.13.2
|
||||||
github.com/docker/cli v28.1.1+incompatible
|
github.com/docker/cli v28.1.1+incompatible
|
||||||
|
github.com/goccy/go-yaml v1.17.1
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||||
github.com/luthermonson/go-proxmox v0.2.2
|
github.com/luthermonson/go-proxmox v0.2.2
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1
|
||||||
github.com/quic-go/quic-go v0.51.0
|
github.com/quic-go/quic-go v0.51.0
|
||||||
github.com/samber/slog-zerolog/v2 v2.7.3
|
github.com/samber/slog-zerolog/v2 v2.7.3
|
||||||
github.com/spf13/afero v1.14.0
|
github.com/spf13/afero v1.14.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -77,6 +77,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
|
||||||
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
|
||||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
|
github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
|
||||||
|
github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e h1:LEbMtJ6loEubxetD+Aw8+1x0rShor5iMoy9WuFQ8hN8=
|
github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e h1:LEbMtJ6loEubxetD+Aw8+1x0rShor5iMoy9WuFQ8hN8=
|
||||||
github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e/go.mod h1:3tMTnTkH7IN5smn7PX83XdmRnNj4Nw2/Pt8GgReqnKM=
|
github.com/godoxy-app/docker v0.0.0-20250418000134-7af8fd7b079e/go.mod h1:3tMTnTkH7IN5smn7PX83XdmRnNj4Nw2/Pt8GgReqnKM=
|
||||||
|
@ -163,6 +165,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
|
||||||
github.com/ovh/go-ovh v1.7.0 h1:V14nF7FwDjQrZt9g7jzcvAAQ3HN6DNShRFRMC3jLoPw=
|
github.com/ovh/go-ovh v1.7.0 h1:V14nF7FwDjQrZt9g7jzcvAAQ3HN6DNShRFRMC3jLoPw=
|
||||||
github.com/ovh/go-ovh v1.7.0/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c=
|
github.com/ovh/go-ovh v1.7.0/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c=
|
||||||
github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc=
|
github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc=
|
||||||
|
|
39
internal/acl/city_cache.go
Normal file
39
internal/acl/city_cache.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
|
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cityCache = xsync.NewMapOf[string, *acl.City]()
|
||||||
|
var numCachedLookup atomic.Uint64
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) {
|
||||||
|
if ip.City != nil {
|
||||||
|
return ip.City, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.db.Reader == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
city, ok := cityCache.Load(ip.Str)
|
||||||
|
if ok {
|
||||||
|
numCachedLookup.Inc()
|
||||||
|
return city, true
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.db.RLock()
|
||||||
|
defer cfg.db.RUnlock()
|
||||||
|
|
||||||
|
city = new(acl.City)
|
||||||
|
err := cfg.db.Lookup(ip.IP, city)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cityCache.Store(ip.Str, city)
|
||||||
|
ip.City = city
|
||||||
|
return city, true
|
||||||
|
}
|
215
internal/acl/config.go
Normal file
215
internal/acl/config.go
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow
|
||||||
|
AllowLocal *bool `json:"allow_local"` // default: true
|
||||||
|
Allow []string `json:"allow"`
|
||||||
|
Deny []string `json:"deny"`
|
||||||
|
Log *accesslog.ACLLoggerConfig `json:"log"`
|
||||||
|
|
||||||
|
MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"`
|
||||||
|
|
||||||
|
config
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
MaxMindDatabaseType string
|
||||||
|
MaxMindConfig struct {
|
||||||
|
AccountID string `json:"account_id" validate:"required"`
|
||||||
|
LicenseKey string `json:"license_key" validate:"required"`
|
||||||
|
Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"`
|
||||||
|
|
||||||
|
logger zerolog.Logger
|
||||||
|
lastUpdate time.Time
|
||||||
|
db struct {
|
||||||
|
*maxminddb.Reader
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
defaultAllow bool
|
||||||
|
allowLocal bool
|
||||||
|
allow []matcher
|
||||||
|
deny []matcher
|
||||||
|
ipCache *xsync.MapOf[string, *checkCache]
|
||||||
|
logAllowed bool
|
||||||
|
logger *accesslog.AccessLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type checkCache struct {
|
||||||
|
*acl.IPInfo
|
||||||
|
allow bool
|
||||||
|
created time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const cacheTTL = 1 * time.Minute
|
||||||
|
|
||||||
|
func (c *checkCache) Expired() bool {
|
||||||
|
return c.created.Add(cacheTTL).After(utils.TimeNow())
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: add stats
|
||||||
|
|
||||||
|
const (
|
||||||
|
ACLAllow = "allow"
|
||||||
|
ACLDeny = "deny"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxMindGeoLite MaxMindDatabaseType = "geolite"
|
||||||
|
MaxMindGeoIP2 MaxMindDatabaseType = "geoip2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Config) Validate() gperr.Error {
|
||||||
|
switch c.Default {
|
||||||
|
case "", ACLAllow:
|
||||||
|
c.defaultAllow = true
|
||||||
|
case ACLDeny:
|
||||||
|
c.defaultAllow = false
|
||||||
|
default:
|
||||||
|
return gperr.New("invalid default value").Subject(c.Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.AllowLocal != nil {
|
||||||
|
c.allowLocal = *c.AllowLocal
|
||||||
|
} else {
|
||||||
|
c.allowLocal = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MaxMind != nil {
|
||||||
|
c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger()
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Log != nil {
|
||||||
|
c.logAllowed = c.Log.LogAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
errs := gperr.NewBuilder("syntax error")
|
||||||
|
c.allow = make([]matcher, 0, len(c.Allow))
|
||||||
|
c.deny = make([]matcher, 0, len(c.Deny))
|
||||||
|
|
||||||
|
for _, s := range c.Allow {
|
||||||
|
m, err := c.parseMatcher(s)
|
||||||
|
if err != nil {
|
||||||
|
errs.Add(err.Subject(s))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.allow = append(c.allow, m)
|
||||||
|
}
|
||||||
|
for _, s := range c.Deny {
|
||||||
|
m, err := c.parseMatcher(s)
|
||||||
|
if err != nil {
|
||||||
|
errs.Add(err.Subject(s))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.deny = append(c.deny, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errs.HasError() {
|
||||||
|
c.allow = nil
|
||||||
|
c.deny = nil
|
||||||
|
return errMatcherFormat.With(errs.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ipCache = xsync.NewMapOf[string, *checkCache]()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Valid() bool {
|
||||||
|
return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Start(parent *task.Task) gperr.Error {
|
||||||
|
if c.MaxMind != nil {
|
||||||
|
if err := c.MaxMind.LoadMaxMindDB(parent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Log != nil {
|
||||||
|
logger, err := accesslog.NewAccessLogger(parent, c.Log)
|
||||||
|
if err != nil {
|
||||||
|
return gperr.New("failed to start access logger").With(err)
|
||||||
|
}
|
||||||
|
c.logger = logger
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) cacheRecord(info *acl.IPInfo, allow bool) {
|
||||||
|
c.ipCache.Store(info.Str, &checkCache{
|
||||||
|
IPInfo: info,
|
||||||
|
allow: allow,
|
||||||
|
created: utils.TimeNow(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *config) log(info *acl.IPInfo, allowed bool) {
|
||||||
|
if c.logger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !allowed || c.logAllowed {
|
||||||
|
c.logger.LogACL(info, !allowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) IPAllowed(ip net.IP) bool {
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// always allow private and loopback
|
||||||
|
// loopback is not logged
|
||||||
|
if ip.IsLoopback() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.allowLocal && ip.IsPrivate() {
|
||||||
|
c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
ipStr := ip.String()
|
||||||
|
record, ok := c.ipCache.Load(ipStr)
|
||||||
|
if ok && !record.Expired() {
|
||||||
|
c.log(record.IPInfo, record.allow)
|
||||||
|
return record.allow
|
||||||
|
}
|
||||||
|
|
||||||
|
ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr}
|
||||||
|
for _, m := range c.allow {
|
||||||
|
if m(ipAndStr) {
|
||||||
|
c.log(ipAndStr, true)
|
||||||
|
c.cacheRecord(ipAndStr, true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range c.deny {
|
||||||
|
if m(ipAndStr) {
|
||||||
|
c.log(ipAndStr, false)
|
||||||
|
c.cacheRecord(ipAndStr, false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log(ipAndStr, c.defaultAllow)
|
||||||
|
c.cacheRecord(ipAndStr, c.defaultAllow)
|
||||||
|
return c.defaultAllow
|
||||||
|
}
|
99
internal/acl/matcher.go
Normal file
99
internal/acl/matcher.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type matcher func(*acl.IPInfo) bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
MatcherTypeIP = "ip"
|
||||||
|
MatcherTypeCIDR = "cidr"
|
||||||
|
MatcherTypeTimeZone = "tz"
|
||||||
|
MatcherTypeISO = "iso"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errMatcherFormat = gperr.Multiline().AddLines(
|
||||||
|
"invalid matcher format, expect {type}:{value}",
|
||||||
|
"Available types: ip|cidr|tz|iso",
|
||||||
|
"ip:127.0.0.1",
|
||||||
|
"cidr:127.0.0.0/8",
|
||||||
|
"tz:Asia/Shanghai",
|
||||||
|
"iso:GB",
|
||||||
|
)
|
||||||
|
var (
|
||||||
|
errSyntax = gperr.New("syntax error")
|
||||||
|
errInvalidIP = gperr.New("invalid IP")
|
||||||
|
errInvalidCIDR = gperr.New("invalid CIDR")
|
||||||
|
errMaxMindNotConfigured = gperr.New("MaxMind not configured")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
|
||||||
|
parts := strings.Split(s, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, errSyntax
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parts[0] {
|
||||||
|
case MatcherTypeIP:
|
||||||
|
ip := net.ParseIP(parts[1])
|
||||||
|
if ip == nil {
|
||||||
|
return nil, errInvalidIP
|
||||||
|
}
|
||||||
|
return matchIP(ip), nil
|
||||||
|
case MatcherTypeCIDR:
|
||||||
|
_, net, err := net.ParseCIDR(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errInvalidCIDR
|
||||||
|
}
|
||||||
|
return matchCIDR(net), nil
|
||||||
|
case MatcherTypeTimeZone:
|
||||||
|
if cfg.MaxMind == nil {
|
||||||
|
return nil, errMaxMindNotConfigured
|
||||||
|
}
|
||||||
|
return cfg.MaxMind.matchTimeZone(parts[1]), nil
|
||||||
|
case MatcherTypeISO:
|
||||||
|
if cfg.MaxMind == nil {
|
||||||
|
return nil, errMaxMindNotConfigured
|
||||||
|
}
|
||||||
|
return cfg.MaxMind.matchISO(parts[1]), nil
|
||||||
|
default:
|
||||||
|
return nil, errSyntax
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchIP(ip net.IP) matcher {
|
||||||
|
return func(ip2 *acl.IPInfo) bool {
|
||||||
|
return ip.Equal(ip2.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchCIDR(n *net.IPNet) matcher {
|
||||||
|
return func(ip *acl.IPInfo) bool {
|
||||||
|
return n.Contains(ip.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
|
||||||
|
return func(ip *acl.IPInfo) bool {
|
||||||
|
city, ok := cfg.lookupCity(ip)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return city.Location.TimeZone == tz
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) matchISO(iso string) matcher {
|
||||||
|
return func(ip *acl.IPInfo) bool {
|
||||||
|
city, ok := cfg.lookupCity(ip)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return city.Country.IsoCode == iso
|
||||||
|
}
|
||||||
|
}
|
281
internal/acl/maxmind.go
Normal file
281
internal/acl/maxmind.go
Normal file
|
@ -0,0 +1,281 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"compress/gzip"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
updateInterval = 24 * time.Hour
|
||||||
|
httpClient = &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
ErrResponseNotOK = gperr.New("response not OK")
|
||||||
|
ErrDownloadFailure = gperr.New("download failure")
|
||||||
|
)
|
||||||
|
|
||||||
|
func dbPathImpl(dbType MaxMindDatabaseType) string {
|
||||||
|
if dbType == MaxMindGeoLite {
|
||||||
|
return filepath.Join(dataDir, "GeoLite2-City.mmdb")
|
||||||
|
}
|
||||||
|
return filepath.Join(dataDir, "GeoIP2-City.mmdb")
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbURLimpl(dbType MaxMindDatabaseType) string {
|
||||||
|
if dbType == MaxMindGeoLite {
|
||||||
|
return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"
|
||||||
|
}
|
||||||
|
return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbFilename(dbType MaxMindDatabaseType) string {
|
||||||
|
if dbType == MaxMindGeoLite {
|
||||||
|
return "GeoLite2-City.mmdb"
|
||||||
|
}
|
||||||
|
return "GeoIP2-City.mmdb"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error {
|
||||||
|
if cfg.Database == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
path := dbPath(cfg.Database)
|
||||||
|
reader, err := maxmindDBOpen(path)
|
||||||
|
exists := true
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, os.ErrNotExist):
|
||||||
|
default:
|
||||||
|
// ignore invalid error, just download it again
|
||||||
|
var invalidErr maxminddb.InvalidDatabaseError
|
||||||
|
if !errors.As(err, &invalidErr) {
|
||||||
|
return gperr.Wrap(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...")
|
||||||
|
reader, err = cfg.download()
|
||||||
|
if err != nil {
|
||||||
|
return ErrDownloadFailure.With(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg.logger.Info().Msg("MaxMind DB loaded")
|
||||||
|
|
||||||
|
cfg.db.Reader = reader
|
||||||
|
go cfg.scheduleUpdate(parent)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) loadLastUpdate() {
|
||||||
|
f, err := os.Stat(dbPath(cfg.Database))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.lastUpdate = f.ModTime()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) setLastUpdate(t time.Time) {
|
||||||
|
cfg.lastUpdate = t
|
||||||
|
_ = os.Chtimes(dbPath(cfg.Database), t, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) {
|
||||||
|
task := parent.Subtask("schedule_update", true)
|
||||||
|
ticker := time.NewTicker(updateInterval)
|
||||||
|
|
||||||
|
cfg.loadLastUpdate()
|
||||||
|
cfg.update()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
if cfg.db.Reader != nil {
|
||||||
|
cfg.db.Reader.Close()
|
||||||
|
}
|
||||||
|
task.Finish(nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-task.Context().Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cfg.update()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) update() {
|
||||||
|
// check for update
|
||||||
|
cfg.logger.Info().Msg("checking for MaxMind DB update...")
|
||||||
|
remoteLastModified, err := cfg.checkLastest()
|
||||||
|
if err != nil {
|
||||||
|
cfg.logger.Err(err).Msg("failed to check MaxMind DB update")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if remoteLastModified.Equal(cfg.lastUpdate) {
|
||||||
|
cfg.logger.Info().Msg("MaxMind DB is up to date")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.logger.Info().
|
||||||
|
Time("latest", remoteLastModified.Local()).
|
||||||
|
Time("current", cfg.lastUpdate).
|
||||||
|
Msg("MaxMind DB update available")
|
||||||
|
reader, err := cfg.download()
|
||||||
|
if err != nil {
|
||||||
|
cfg.logger.Err(err).Msg("failed to update MaxMind DB")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.db.Lock()
|
||||||
|
cfg.db.Close()
|
||||||
|
cfg.db.Reader = reader
|
||||||
|
cfg.setLastUpdate(*remoteLastModified)
|
||||||
|
cfg.db.Unlock()
|
||||||
|
|
||||||
|
cfg.logger.Info().Msg("MaxMind DB updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest(method, dbURL(cfg.Database), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) {
|
||||||
|
resp, err := newReq(cfg, http.MethodHead)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastModified := resp.Header.Get("Last-Modified")
|
||||||
|
if lastModified == "" {
|
||||||
|
cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
|
||||||
|
if err != nil {
|
||||||
|
cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &lastModifiedTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *MaxMindConfig) download() (*maxminddb.Reader, error) {
|
||||||
|
resp, err := newReq(cfg, http.MethodGet)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := dbPath(cfg.Database)
|
||||||
|
tmpPath := path + "-tmp.tar.gz"
|
||||||
|
file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY, 0o644)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.logger.Info().Msg("MaxMind DB downloading...")
|
||||||
|
|
||||||
|
_, err = io.Copy(file, resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
file.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file.Close()
|
||||||
|
|
||||||
|
// extract .tar.gz and move only the dbFilename to path
|
||||||
|
err = extractFileFromTarGz(tmpPath, dbFilename(cfg.Database), path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, gperr.New("failed to extract database from archive").With(err)
|
||||||
|
}
|
||||||
|
// cleanup the tar.gz file
|
||||||
|
_ = os.Remove(tmpPath)
|
||||||
|
|
||||||
|
db, err := maxmindDBOpen(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractFileFromTarGz(tarGzPath, targetFilename, destPath string) error {
|
||||||
|
f, err := os.Open(tarGzPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
gzr, err := gzip.NewReader(f)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer gzr.Close()
|
||||||
|
|
||||||
|
tr := tar.NewReader(gzr)
|
||||||
|
for {
|
||||||
|
hdr, err := tr.Next()
|
||||||
|
if err == io.EOF {
|
||||||
|
break // End of archive
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Only extract the file that matches targetFilename (basename match)
|
||||||
|
if filepath.Base(hdr.Name) == targetFilename {
|
||||||
|
outFile, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, hdr.FileInfo().Mode())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer outFile.Close()
|
||||||
|
_, err = io.Copy(outFile, tr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil // Done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("file %s not found in archive", targetFilename)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
dataDir = common.DataDir
|
||||||
|
dbURL = dbURLimpl
|
||||||
|
dbPath = dbPathImpl
|
||||||
|
maxmindDBOpen = maxminddb.Open
|
||||||
|
newReq = (*MaxMindConfig).newReq
|
||||||
|
)
|
213
internal/acl/maxmind_test.go
Normal file
213
internal/acl/maxmind_test.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_dbPath(t *testing.T) {
|
||||||
|
tmpDataDir := "/tmp/testdata"
|
||||||
|
oldDataDir := dataDir
|
||||||
|
dataDir = tmpDataDir
|
||||||
|
defer func() { dataDir = oldDataDir }()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dbType MaxMindDatabaseType
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")},
|
||||||
|
{"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := dbPath(tt.dbType); got != tt.want {
|
||||||
|
t.Errorf("dbPath() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_dbURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dbType MaxMindDatabaseType
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"},
|
||||||
|
{"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := dbURL(tt.dbType); got != tt.want {
|
||||||
|
t.Errorf("dbURL() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper for MaxMindConfig ---
|
||||||
|
type testLogger struct{ zerolog.Logger }
|
||||||
|
|
||||||
|
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
|
||||||
|
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
|
||||||
|
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
|
||||||
|
|
||||||
|
func Test_MaxMindConfig_newReq(t *testing.T) {
|
||||||
|
cfg := &MaxMindConfig{
|
||||||
|
AccountID: "testid",
|
||||||
|
LicenseKey: "testkey",
|
||||||
|
Database: MaxMindGeoLite,
|
||||||
|
logger: zerolog.Nop(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch httpClient to use httptest
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
|
||||||
|
t.Errorf("basic auth not set correctly")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
oldURL := dbURL
|
||||||
|
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||||
|
defer func() { dbURL = oldURL }()
|
||||||
|
|
||||||
|
resp, err := cfg.newReq(http.MethodGet)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newReq() error = %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("unexpected status: %v", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_MaxMindConfig_checkUpdate(t *testing.T) {
|
||||||
|
cfg := &MaxMindConfig{
|
||||||
|
AccountID: "id",
|
||||||
|
LicenseKey: "key",
|
||||||
|
Database: MaxMindGeoLite,
|
||||||
|
logger: zerolog.Nop(),
|
||||||
|
}
|
||||||
|
lastMod := time.Now().UTC().Format(http.TimeFormat)
|
||||||
|
buildTime := time.Now().Add(-time.Hour)
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Last-Modified", lastMod)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
oldURL := dbURL
|
||||||
|
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||||
|
defer func() { dbURL = oldURL }()
|
||||||
|
|
||||||
|
latest, err := cfg.checkLastest()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("checkUpdate() error = %v", err)
|
||||||
|
}
|
||||||
|
if latest.Equal(buildTime) {
|
||||||
|
t.Errorf("expected update needed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeReadCloser struct {
|
||||||
|
firstRead bool
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeReadCloser) Read(p []byte) (int, error) {
|
||||||
|
if !c.firstRead {
|
||||||
|
c.firstRead = true
|
||||||
|
return strings.NewReader("FAKEMMDB").Read(p)
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeReadCloser) Close() error {
|
||||||
|
c.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_MaxMindConfig_download(t *testing.T) {
|
||||||
|
cfg := &MaxMindConfig{
|
||||||
|
AccountID: "id",
|
||||||
|
LicenseKey: "key",
|
||||||
|
Database: MaxMindGeoLite,
|
||||||
|
logger: zerolog.Nop(),
|
||||||
|
}
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
io.Copy(w, strings.NewReader("FAKEMMDB"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
oldURL := dbURL
|
||||||
|
dbURL = func(MaxMindDatabaseType) string { return server.URL }
|
||||||
|
defer func() { dbURL = oldURL }()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
oldDataDir := dataDir
|
||||||
|
dataDir = tmpDir
|
||||||
|
defer func() { dataDir = oldDataDir }()
|
||||||
|
|
||||||
|
// Patch maxminddb.Open to always succeed
|
||||||
|
origOpen := maxmindDBOpen
|
||||||
|
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
||||||
|
return &maxminddb.Reader{}, nil
|
||||||
|
}
|
||||||
|
defer func() { maxmindDBOpen = origOpen }()
|
||||||
|
|
||||||
|
rw := &fakeReadCloser{}
|
||||||
|
oldNewReq := newReq
|
||||||
|
newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: rw,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
defer func() { newReq = oldNewReq }()
|
||||||
|
|
||||||
|
db, err := cfg.download()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("download() error = %v", err)
|
||||||
|
}
|
||||||
|
if db == nil {
|
||||||
|
t.Error("expected db instance")
|
||||||
|
}
|
||||||
|
if !rw.closed {
|
||||||
|
t.Error("expected rw to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
|
||||||
|
// This test should cover both the path where DB exists and where it does not
|
||||||
|
// For brevity, only the non-existing path is tested here
|
||||||
|
cfg := &MaxMindConfig{
|
||||||
|
AccountID: "id",
|
||||||
|
LicenseKey: "key",
|
||||||
|
Database: MaxMindGeoLite,
|
||||||
|
logger: zerolog.Nop(),
|
||||||
|
}
|
||||||
|
oldOpen := maxmindDBOpen
|
||||||
|
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
|
||||||
|
return &maxminddb.Reader{}, nil
|
||||||
|
}
|
||||||
|
defer func() { maxmindDBOpen = oldOpen }()
|
||||||
|
|
||||||
|
oldDBPath := dbPath
|
||||||
|
dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") }
|
||||||
|
defer func() { dbPath = oldDBPath }()
|
||||||
|
|
||||||
|
task := task.RootTask("test")
|
||||||
|
defer task.Finish(nil)
|
||||||
|
err := cfg.LoadMaxMindDB(task)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("loadMaxMindDB() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
46
internal/acl/tcp_listener.go
Normal file
46
internal/acl/tcp_listener.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPListener struct {
|
||||||
|
acl *Config
|
||||||
|
lis net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) WrapTCP(lis net.Listener) net.Listener {
|
||||||
|
if cfg == nil {
|
||||||
|
return lis
|
||||||
|
}
|
||||||
|
return &TCPListener{
|
||||||
|
acl: cfg,
|
||||||
|
lis: lis,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPListener) Addr() net.Addr {
|
||||||
|
return s.lis.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPListener) Accept() (net.Conn, error) {
|
||||||
|
c, err := s.lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
addr, ok := c.RemoteAddr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
// Not a TCPAddr, drop
|
||||||
|
c.Close()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if !s.acl.IPAllowed(addr.IP) {
|
||||||
|
c.Close()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPListener) Close() error {
|
||||||
|
return s.lis.Close()
|
||||||
|
}
|
10
internal/acl/types/city_info.go
Normal file
10
internal/acl/types/city_info.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
type City struct {
|
||||||
|
Location struct {
|
||||||
|
TimeZone string `maxminddb:"time_zone"`
|
||||||
|
} `maxminddb:"location"`
|
||||||
|
Country struct {
|
||||||
|
IsoCode string `maxminddb:"iso_code"`
|
||||||
|
} `maxminddb:"country"`
|
||||||
|
}
|
9
internal/acl/types/ip_info.go
Normal file
9
internal/acl/types/ip_info.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
type IPInfo struct {
|
||||||
|
IP net.IP
|
||||||
|
Str string
|
||||||
|
City *City
|
||||||
|
}
|
79
internal/acl/udp_listener.go
Normal file
79
internal/acl/udp_listener.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package acl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UDPListener struct {
|
||||||
|
acl *Config
|
||||||
|
lis net.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) WrapUDP(lis net.PacketConn) net.PacketConn {
|
||||||
|
if cfg == nil {
|
||||||
|
return lis
|
||||||
|
}
|
||||||
|
return &UDPListener{
|
||||||
|
acl: cfg,
|
||||||
|
lis: lis,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) LocalAddr() net.Addr {
|
||||||
|
return s.lis.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||||
|
for {
|
||||||
|
n, addr, err := s.lis.ReadFrom(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, addr, err
|
||||||
|
}
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
// Not a UDPAddr, drop
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.acl.IPAllowed(udpAddr.IP) {
|
||||||
|
// Drop packet from disallowed IP
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, addr, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
|
for {
|
||||||
|
n, err := s.lis.WriteTo(p, addr)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
// Not a UDPAddr, drop
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !s.acl.IPAllowed(udpAddr.IP) {
|
||||||
|
// Drop packet to disallowed IP
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) SetDeadline(t time.Time) error {
|
||||||
|
return s.lis.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) SetReadDeadline(t time.Time) error {
|
||||||
|
return s.lis.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) SetWriteDeadline(t time.Time) error {
|
||||||
|
return s.lis.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPListener) Close() error {
|
||||||
|
return s.lis.Close()
|
||||||
|
}
|
|
@ -197,6 +197,7 @@ func (cfg *Config) StartServers(opts ...*StartServersOptions) {
|
||||||
HTTPAddr: common.ProxyHTTPAddr,
|
HTTPAddr: common.ProxyHTTPAddr,
|
||||||
HTTPSAddr: common.ProxyHTTPSAddr,
|
HTTPSAddr: common.ProxyHTTPSAddr,
|
||||||
Handler: cfg.entrypoint,
|
Handler: cfg.entrypoint,
|
||||||
|
ACL: cfg.value.ACL,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if opt.API {
|
if opt.API {
|
||||||
|
@ -237,6 +238,14 @@ func (cfg *Config) load() gperr.Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg.entrypoint.SetFindRouteDomains(model.MatchDomains)
|
cfg.entrypoint.SetFindRouteDomains(model.MatchDomains)
|
||||||
|
if model.ACL.Valid() {
|
||||||
|
err := model.ACL.Start(cfg.task)
|
||||||
|
if err != nil {
|
||||||
|
errs.Add(err)
|
||||||
|
} else {
|
||||||
|
logging.Info().Msg("ACL started")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package types
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -7,15 +7,17 @@ import (
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
|
"github.com/yusing/go-proxy/internal/acl"
|
||||||
"github.com/yusing/go-proxy/internal/autocert"
|
"github.com/yusing/go-proxy/internal/autocert"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/accesslog"
|
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/notif"
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Config struct {
|
Config struct {
|
||||||
|
ACL *acl.Config `json:"acl"`
|
||||||
AutoCert *autocert.AutocertConfig `json:"autocert"`
|
AutoCert *autocert.AutocertConfig `json:"autocert"`
|
||||||
Entrypoint Entrypoint `json:"entrypoint"`
|
Entrypoint Entrypoint `json:"entrypoint"`
|
||||||
Providers Providers `json:"providers"`
|
Providers Providers `json:"providers"`
|
||||||
|
@ -31,7 +33,10 @@ type (
|
||||||
}
|
}
|
||||||
Entrypoint struct {
|
Entrypoint struct {
|
||||||
Middlewares []map[string]any `json:"middlewares"`
|
Middlewares []map[string]any `json:"middlewares"`
|
||||||
AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"`
|
AccessLog *accesslog.RequestLoggerConfig `json:"access_log" validate:"omitempty"`
|
||||||
|
}
|
||||||
|
HomepageConfig struct {
|
||||||
|
UseDefaultCategories bool `json:"use_default_categories"`
|
||||||
}
|
}
|
||||||
|
|
||||||
ConfigInstance interface {
|
ConfigInstance interface {
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
package types
|
|
||||||
|
|
||||||
type HomepageConfig struct {
|
|
||||||
UseDefaultCategories bool `json:"use_default_categories"`
|
|
||||||
}
|
|
|
@ -54,7 +54,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
|
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) (err error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
ep.accessLogger = nil
|
ep.accessLogger = nil
|
||||||
return
|
return
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -19,9 +20,12 @@ type (
|
||||||
AccessLogger struct {
|
AccessLogger struct {
|
||||||
task *task.Task
|
task *task.Task
|
||||||
cfg *Config
|
cfg *Config
|
||||||
io AccessLogIO
|
|
||||||
buffered *bufio.Writer
|
closer []io.Closer
|
||||||
supportRotate bool
|
supportRotate []supportRotate
|
||||||
|
writer *bufio.Writer
|
||||||
|
writeLock sync.Mutex
|
||||||
|
closed bool
|
||||||
|
|
||||||
lineBufPool *synk.BytesPool // buffer pool for formatting a single log line
|
lineBufPool *synk.BytesPool // buffer pool for formatting a single log line
|
||||||
|
|
||||||
|
@ -29,85 +33,104 @@ type (
|
||||||
|
|
||||||
logger zerolog.Logger
|
logger zerolog.Logger
|
||||||
|
|
||||||
Formatter
|
RequestFormatter
|
||||||
|
ACLFormatter
|
||||||
}
|
}
|
||||||
|
|
||||||
AccessLogIO interface {
|
WriterWithName interface {
|
||||||
io.Writer
|
io.Writer
|
||||||
sync.Locker
|
|
||||||
Name() string // file name or path
|
Name() string // file name or path
|
||||||
}
|
}
|
||||||
|
|
||||||
Formatter interface {
|
RequestFormatter interface {
|
||||||
// AppendLog appends a log line to line with or without a trailing newline
|
// AppendRequestLog appends a log line to line with or without a trailing newline
|
||||||
AppendLog(line []byte, req *http.Request, res *http.Response) []byte
|
AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte
|
||||||
|
}
|
||||||
|
ACLFormatter interface {
|
||||||
|
// AppendACLLog appends a log line to line with or without a trailing newline
|
||||||
|
AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const MinBufferSize = 4 * kilobyte
|
const (
|
||||||
|
MinBufferSize = 4 * kilobyte
|
||||||
|
MaxBufferSize = 1 * megabyte
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
flushInterval = 30 * time.Second
|
flushInterval = 30 * time.Second
|
||||||
rotateInterval = time.Hour
|
rotateInterval = time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
|
const (
|
||||||
var ios []AccessLogIO
|
errRateLimit = 200 * time.Millisecond
|
||||||
|
errBurst = 5
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.Stdout {
|
func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
|
||||||
ios = append(ios, stdoutIO)
|
io, err := cfg.IO()
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Path != "" {
|
|
||||||
io, err := newFileIO(cfg.Path)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ios = append(ios, io)
|
return NewAccessLoggerWithIO(parent, io, cfg), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ios) == 0 {
|
func NewMockAccessLogger(parent task.Parent, cfg *RequestLoggerConfig) *AccessLogger {
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewAccessLoggerWithIO(parent, NewMultiWriter(ios...), cfg), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMockAccessLogger(parent task.Parent, cfg *Config) *AccessLogger {
|
|
||||||
return NewAccessLoggerWithIO(parent, NewMockFile(), cfg)
|
return NewAccessLoggerWithIO(parent, NewMockFile(), cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessLoggerWithIO(parent task.Parent, io AccessLogIO, cfg *Config) *AccessLogger {
|
func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg AnyConfig) *AccessLogger {
|
||||||
|
cfg := anyCfg.ToConfig()
|
||||||
if cfg.BufferSize == 0 {
|
if cfg.BufferSize == 0 {
|
||||||
cfg.BufferSize = DefaultBufferSize
|
cfg.BufferSize = DefaultBufferSize
|
||||||
}
|
}
|
||||||
if cfg.BufferSize < MinBufferSize {
|
if cfg.BufferSize < MinBufferSize {
|
||||||
cfg.BufferSize = MinBufferSize
|
cfg.BufferSize = MinBufferSize
|
||||||
}
|
}
|
||||||
|
if cfg.BufferSize > MaxBufferSize {
|
||||||
|
cfg.BufferSize = MaxBufferSize
|
||||||
|
}
|
||||||
l := &AccessLogger{
|
l := &AccessLogger{
|
||||||
task: parent.Subtask("accesslog."+io.Name(), true),
|
task: parent.Subtask("accesslog."+writer.Name(), true),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
io: io,
|
writer: bufio.NewWriterSize(writer, cfg.BufferSize),
|
||||||
buffered: bufio.NewWriterSize(io, cfg.BufferSize),
|
lineBufPool: synk.NewBytesPool(512, 8192),
|
||||||
lineBufPool: synk.NewBytesPool(1024, synk.DefaultMaxBytes),
|
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
|
||||||
errRateLimiter: rate.NewLimiter(rate.Every(time.Second), 1),
|
logger: logging.With().Str("file", writer.Name()).Logger(),
|
||||||
logger: logging.With().Str("file", io.Name()).Logger(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt := CommonFormatter{cfg: &l.cfg.Fields}
|
if unwrapped, ok := writer.(MultiWriterInterface); ok {
|
||||||
switch l.cfg.Format {
|
for _, w := range unwrapped.Unwrap() {
|
||||||
|
if sr, ok := w.(supportRotate); ok {
|
||||||
|
l.supportRotate = append(l.supportRotate, sr)
|
||||||
|
}
|
||||||
|
if closer, ok := w.(io.Closer); ok {
|
||||||
|
l.closer = append(l.closer, closer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if sr, ok := writer.(supportRotate); ok {
|
||||||
|
l.supportRotate = append(l.supportRotate, sr)
|
||||||
|
}
|
||||||
|
if closer, ok := writer.(io.Closer); ok {
|
||||||
|
l.closer = append(l.closer, closer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.req != nil {
|
||||||
|
fmt := CommonFormatter{cfg: &cfg.req.Fields}
|
||||||
|
switch cfg.req.Format {
|
||||||
case FormatCommon:
|
case FormatCommon:
|
||||||
l.Formatter = &fmt
|
l.RequestFormatter = &fmt
|
||||||
case FormatCombined:
|
case FormatCombined:
|
||||||
l.Formatter = &CombinedFormatter{fmt}
|
l.RequestFormatter = &CombinedFormatter{fmt}
|
||||||
case FormatJSON:
|
case FormatJSON:
|
||||||
l.Formatter = &JSONFormatter{fmt}
|
l.RequestFormatter = &JSONFormatter{fmt}
|
||||||
default: // should not happen, validation has done by validate tags
|
default: // should not happen, validation has done by validate tags
|
||||||
panic("invalid access log format")
|
panic("invalid access log format")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
if _, ok := l.io.(supportRotate); ok {
|
l.ACLFormatter = ACLLogFormatter{}
|
||||||
l.supportRotate = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go l.start()
|
go l.start()
|
||||||
|
@ -119,10 +142,10 @@ func (l *AccessLogger) Config() *Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool {
|
func (l *AccessLogger) shouldLog(req *http.Request, res *http.Response) bool {
|
||||||
if !l.cfg.Filters.StatusCodes.CheckKeep(req, res) ||
|
if !l.cfg.req.Filters.StatusCodes.CheckKeep(req, res) ||
|
||||||
!l.cfg.Filters.Method.CheckKeep(req, res) ||
|
!l.cfg.req.Filters.Method.CheckKeep(req, res) ||
|
||||||
!l.cfg.Filters.Headers.CheckKeep(req, res) ||
|
!l.cfg.req.Filters.Headers.CheckKeep(req, res) ||
|
||||||
!l.cfg.Filters.CIDR.CheckKeep(req, res) {
|
!l.cfg.req.Filters.CIDR.CheckKeep(req, res) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
@ -135,19 +158,29 @@ func (l *AccessLogger) Log(req *http.Request, res *http.Response) {
|
||||||
|
|
||||||
line := l.lineBufPool.Get()
|
line := l.lineBufPool.Get()
|
||||||
defer l.lineBufPool.Put(line)
|
defer l.lineBufPool.Put(line)
|
||||||
line = l.Formatter.AppendLog(line, req, res)
|
line = l.AppendRequestLog(line, req, res)
|
||||||
if line[len(line)-1] != '\n' {
|
if line[len(line)-1] != '\n' {
|
||||||
line = append(line, '\n')
|
line = append(line, '\n')
|
||||||
}
|
}
|
||||||
l.lockWrite(line)
|
l.write(line)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
func (l *AccessLogger) LogError(req *http.Request, err error) {
|
||||||
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
|
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *AccessLogger) LogACL(info *acl.IPInfo, blocked bool) {
|
||||||
|
line := l.lineBufPool.Get()
|
||||||
|
defer l.lineBufPool.Put(line)
|
||||||
|
line = l.ACLFormatter.AppendACLLog(line, info, blocked)
|
||||||
|
if line[len(line)-1] != '\n' {
|
||||||
|
line = append(line, '\n')
|
||||||
|
}
|
||||||
|
l.write(line)
|
||||||
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) ShouldRotate() bool {
|
func (l *AccessLogger) ShouldRotate() bool {
|
||||||
return l.cfg.Retention.IsValid() && l.supportRotate
|
return l.supportRotate != nil && l.cfg.Retention.IsValid()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
||||||
|
@ -155,10 +188,21 @@ func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
l.io.Lock()
|
l.writer.Flush()
|
||||||
defer l.io.Unlock()
|
l.writeLock.Lock()
|
||||||
|
defer l.writeLock.Unlock()
|
||||||
|
|
||||||
return rotateLogFile(l.io.(supportRotate), l.cfg.Retention)
|
result = new(RotateResult)
|
||||||
|
for _, sr := range l.supportRotate {
|
||||||
|
r, err := rotateLogFile(sr, l.cfg.Retention)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if r != nil {
|
||||||
|
result.Add(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) handleErr(err error) {
|
func (l *AccessLogger) handleErr(err error) {
|
||||||
|
@ -172,11 +216,9 @@ func (l *AccessLogger) handleErr(err error) {
|
||||||
|
|
||||||
func (l *AccessLogger) start() {
|
func (l *AccessLogger) start() {
|
||||||
defer func() {
|
defer func() {
|
||||||
defer l.task.Finish(nil)
|
l.Flush()
|
||||||
defer l.close()
|
l.Close()
|
||||||
if err := l.Flush(); err != nil {
|
l.task.Finish(nil)
|
||||||
l.handleErr(err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// flushes the buffer every 30 seconds
|
// flushes the buffer every 30 seconds
|
||||||
|
@ -191,9 +233,7 @@ func (l *AccessLogger) start() {
|
||||||
case <-l.task.Context().Done():
|
case <-l.task.Context().Done():
|
||||||
return
|
return
|
||||||
case <-flushTicker.C:
|
case <-flushTicker.C:
|
||||||
if err := l.Flush(); err != nil {
|
l.Flush()
|
||||||
l.handleErr(err)
|
|
||||||
}
|
|
||||||
case <-rotateTicker.C:
|
case <-rotateTicker.C:
|
||||||
if !l.ShouldRotate() {
|
if !l.ShouldRotate() {
|
||||||
continue
|
continue
|
||||||
|
@ -210,27 +250,40 @@ func (l *AccessLogger) start() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) Flush() error {
|
func (l *AccessLogger) Close() error {
|
||||||
l.io.Lock()
|
l.writeLock.Lock()
|
||||||
defer l.io.Unlock()
|
defer l.writeLock.Unlock()
|
||||||
return l.buffered.Flush()
|
if l.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if l.closer != nil {
|
||||||
|
for _, c := range l.closer {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.closed = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) close() {
|
func (l *AccessLogger) Flush() {
|
||||||
if r, ok := l.io.(io.Closer); ok {
|
l.writeLock.Lock()
|
||||||
l.io.Lock()
|
defer l.writeLock.Unlock()
|
||||||
defer l.io.Unlock()
|
if l.closed {
|
||||||
r.Close()
|
return
|
||||||
|
}
|
||||||
|
if err := l.writer.Flush(); err != nil {
|
||||||
|
l.handleErr(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *AccessLogger) lockWrite(data []byte) {
|
func (l *AccessLogger) write(data []byte) {
|
||||||
l.io.Lock() // prevent concurrent write, i.e. log rotation, other access loggers
|
l.writeLock.Lock()
|
||||||
_, err := l.buffered.Write(data)
|
defer l.writeLock.Unlock()
|
||||||
l.io.Unlock()
|
if l.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err := l.writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.handleErr(err)
|
l.handleErr(err)
|
||||||
} else {
|
|
||||||
logging.Trace().Msg("access log flushed to " + l.io.Name())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,18 +52,18 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func fmtLog(cfg *Config) (ts string, line string) {
|
func fmtLog(cfg *RequestLoggerConfig) (ts string, line string) {
|
||||||
buf := make([]byte, 0, 1024)
|
buf := make([]byte, 0, 1024)
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
logger := NewMockAccessLogger(testTask, cfg)
|
logger := NewMockAccessLogger(testTask, cfg)
|
||||||
utils.MockTimeNow(t)
|
utils.MockTimeNow(t)
|
||||||
buf = logger.AppendLog(buf, req, resp)
|
buf = logger.AppendRequestLog(buf, req, resp)
|
||||||
return t.Format(LogTimeFormat), string(buf)
|
return t.Format(LogTimeFormat), string(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerCommon(t *testing.T) {
|
func TestAccessLoggerCommon(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Format = FormatCommon
|
config.Format = FormatCommon
|
||||||
ts, log := fmtLog(config)
|
ts, log := fmtLog(config)
|
||||||
expect.Equal(t, log,
|
expect.Equal(t, log,
|
||||||
|
@ -74,7 +74,7 @@ func TestAccessLoggerCommon(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerCombined(t *testing.T) {
|
func TestAccessLoggerCombined(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Format = FormatCombined
|
config.Format = FormatCombined
|
||||||
ts, log := fmtLog(config)
|
ts, log := fmtLog(config)
|
||||||
expect.Equal(t, log,
|
expect.Equal(t, log,
|
||||||
|
@ -85,7 +85,7 @@ func TestAccessLoggerCombined(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerRedactQuery(t *testing.T) {
|
func TestAccessLoggerRedactQuery(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Format = FormatCommon
|
config.Format = FormatCommon
|
||||||
config.Fields.Query.Default = FieldModeRedact
|
config.Fields.Query.Default = FieldModeRedact
|
||||||
ts, log := fmtLog(config)
|
ts, log := fmtLog(config)
|
||||||
|
@ -115,7 +115,7 @@ type JSONLogEntry struct {
|
||||||
Cookies map[string]string `json:"cookies,omitempty"`
|
Cookies map[string]string `json:"cookies,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
|
func getJSONEntry(t *testing.T, config *RequestLoggerConfig) JSONLogEntry {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
config.Format = FormatJSON
|
config.Format = FormatJSON
|
||||||
var entry JSONLogEntry
|
var entry JSONLogEntry
|
||||||
|
@ -126,7 +126,7 @@ func getJSONEntry(t *testing.T, config *Config) JSONLogEntry {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSON(t *testing.T) {
|
func TestAccessLoggerJSON(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
expect.Equal(t, entry.IP, remote)
|
expect.Equal(t, entry.IP, remote)
|
||||||
expect.Equal(t, entry.Method, method)
|
expect.Equal(t, entry.Method, method)
|
||||||
|
@ -147,7 +147,7 @@ func TestAccessLoggerJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAccessLoggerJSON(b *testing.B) {
|
func BenchmarkAccessLoggerJSON(b *testing.B) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Format = FormatJSON
|
config.Format = FormatJSON
|
||||||
logger := NewMockAccessLogger(testTask, config)
|
logger := NewMockAccessLogger(testTask, config)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -157,7 +157,7 @@ func BenchmarkAccessLoggerJSON(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAccessLoggerCombined(b *testing.B) {
|
func BenchmarkAccessLoggerCombined(b *testing.B) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Format = FormatCombined
|
config.Format = FormatCombined
|
||||||
logger := NewMockAccessLogger(testTask, config)
|
logger := NewMockAccessLogger(testTask, config)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
|
@ -6,9 +6,14 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ReaderAtSeeker interface {
|
||||||
|
io.ReaderAt
|
||||||
|
io.Seeker
|
||||||
|
}
|
||||||
|
|
||||||
// BackScanner provides an interface to read a file backward line by line.
|
// BackScanner provides an interface to read a file backward line by line.
|
||||||
type BackScanner struct {
|
type BackScanner struct {
|
||||||
file supportRotate
|
file ReaderAtSeeker
|
||||||
size int64
|
size int64
|
||||||
chunkSize int
|
chunkSize int
|
||||||
chunkBuf []byte
|
chunkBuf []byte
|
||||||
|
@ -21,7 +26,7 @@ type BackScanner struct {
|
||||||
|
|
||||||
// NewBackScanner creates a new Scanner to read the file backward.
|
// NewBackScanner creates a new Scanner to read the file backward.
|
||||||
// chunkSize determines the size of each read chunk from the end of the file.
|
// chunkSize determines the size of each read chunk from the end of the file.
|
||||||
func NewBackScanner(file supportRotate, chunkSize int) *BackScanner {
|
func NewBackScanner(file ReaderAtSeeker, chunkSize int) *BackScanner {
|
||||||
size, err := file.Seek(0, io.SeekEnd)
|
size, err := file.Seek(0, io.SeekEnd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &BackScanner{err: err}
|
return &BackScanner{err: err}
|
||||||
|
@ -29,7 +34,7 @@ func NewBackScanner(file supportRotate, chunkSize int) *BackScanner {
|
||||||
return newBackScanner(file, size, make([]byte, chunkSize))
|
return newBackScanner(file, size, make([]byte, chunkSize))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBackScanner(file supportRotate, fileSize int64, buf []byte) *BackScanner {
|
func newBackScanner(file ReaderAtSeeker, fileSize int64, buf []byte) *BackScanner {
|
||||||
return &BackScanner{
|
return &BackScanner{
|
||||||
file: file,
|
file: file,
|
||||||
size: fileSize,
|
size: fileSize,
|
||||||
|
|
|
@ -135,7 +135,7 @@ func TestBackScannerWithVaryingChunkSizes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func logEntry() []byte {
|
func logEntry() []byte {
|
||||||
accesslog := NewMockAccessLogger(task.RootTask("test", false), &Config{
|
accesslog := NewMockAccessLogger(task.RootTask("test", false), &RequestLoggerConfig{
|
||||||
Format: FormatJSON,
|
Format: FormatJSON,
|
||||||
})
|
})
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -148,7 +148,7 @@ func logEntry() []byte {
|
||||||
res := httptest.NewRecorder()
|
res := httptest.NewRecorder()
|
||||||
// server the request
|
// server the request
|
||||||
srv.Config.Handler.ServeHTTP(res, req)
|
srv.Config.Handler.ServeHTTP(res, req)
|
||||||
b := accesslog.AppendLog(nil, req, res.Result())
|
b := accesslog.AppendRequestLog(nil, req, res.Result())
|
||||||
if b[len(b)-1] != '\n' {
|
if b[len(b)-1] != '\n' {
|
||||||
b = append(b, '\n')
|
b = append(b, '\n')
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,32 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
ConfigBase struct {
|
||||||
|
BufferSize int `json:"buffer_size"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Stdout bool `json:"stdout"`
|
||||||
|
Retention *Retention `json:"retention" aliases:"keep"`
|
||||||
|
}
|
||||||
|
ACLLoggerConfig struct {
|
||||||
|
ConfigBase
|
||||||
|
LogAllowed bool `json:"log_allowed"`
|
||||||
|
}
|
||||||
|
RequestLoggerConfig struct {
|
||||||
|
ConfigBase
|
||||||
|
Format Format `json:"format" validate:"oneof=common combined json"`
|
||||||
|
Filters Filters `json:"filters"`
|
||||||
|
Fields Fields `json:"fields"`
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
*ConfigBase
|
||||||
|
acl *ACLLoggerConfig
|
||||||
|
req *RequestLoggerConfig
|
||||||
|
}
|
||||||
|
AnyConfig interface {
|
||||||
|
ToConfig() *Config
|
||||||
|
IO() (WriterWithName, error)
|
||||||
|
}
|
||||||
|
|
||||||
Format string
|
Format string
|
||||||
Filters struct {
|
Filters struct {
|
||||||
StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"`
|
StatusCodes LogFilter[*StatusCodeRange] `json:"status_codes"`
|
||||||
|
@ -19,15 +45,6 @@ type (
|
||||||
Query FieldConfig `json:"query"`
|
Query FieldConfig `json:"query"`
|
||||||
Cookies FieldConfig `json:"cookies"`
|
Cookies FieldConfig `json:"cookies"`
|
||||||
}
|
}
|
||||||
Config struct {
|
|
||||||
BufferSize int `json:"buffer_size"`
|
|
||||||
Format Format `json:"format" validate:"oneof=common combined json"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Stdout bool `json:"stdout"`
|
|
||||||
Filters Filters `json:"filters"`
|
|
||||||
Fields Fields `json:"fields"`
|
|
||||||
Retention *Retention `json:"retention"`
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -35,23 +52,57 @@ var (
|
||||||
FormatCombined Format = "combined"
|
FormatCombined Format = "combined"
|
||||||
FormatJSON Format = "json"
|
FormatJSON Format = "json"
|
||||||
|
|
||||||
AvailableFormats = []Format{FormatCommon, FormatCombined, FormatJSON}
|
ReqLoggerFormats = []Format{FormatCommon, FormatCombined, FormatJSON}
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultBufferSize = 64 * kilobyte // 64KB
|
const DefaultBufferSize = 64 * kilobyte // 64KB
|
||||||
|
|
||||||
func (cfg *Config) Validate() gperr.Error {
|
func (cfg *ConfigBase) Validate() gperr.Error {
|
||||||
if cfg.Path == "" && !cfg.Stdout {
|
if cfg.Path == "" && !cfg.Stdout {
|
||||||
return gperr.New("path or stdout is required")
|
return gperr.New("path or stdout is required")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultConfig() *Config {
|
func (cfg *ConfigBase) IO() (WriterWithName, error) {
|
||||||
|
ios := make([]WriterWithName, 0, 2)
|
||||||
|
if cfg.Stdout {
|
||||||
|
ios = append(ios, stdoutIO)
|
||||||
|
}
|
||||||
|
if cfg.Path != "" {
|
||||||
|
io, err := newFileIO(cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ios = append(ios, io)
|
||||||
|
}
|
||||||
|
if len(ios) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return NewMultiWriter(ios...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *ACLLoggerConfig) ToConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
|
ConfigBase: &cfg.ConfigBase,
|
||||||
|
acl: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *RequestLoggerConfig) ToConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
ConfigBase: &cfg.ConfigBase,
|
||||||
|
req: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultRequestLoggerConfig() *RequestLoggerConfig {
|
||||||
|
return &RequestLoggerConfig{
|
||||||
|
ConfigBase: ConfigBase{
|
||||||
BufferSize: DefaultBufferSize,
|
BufferSize: DefaultBufferSize,
|
||||||
Format: FormatCombined,
|
|
||||||
Retention: &Retention{Days: 30},
|
Retention: &Retention{Days: 30},
|
||||||
|
},
|
||||||
|
Format: FormatCombined,
|
||||||
Fields: Fields{
|
Fields: Fields{
|
||||||
Headers: FieldConfig{
|
Headers: FieldConfig{
|
||||||
Default: FieldModeDrop,
|
Default: FieldModeDrop,
|
||||||
|
@ -66,6 +117,16 @@ func DefaultConfig() *Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func DefaultACLLoggerConfig() *ACLLoggerConfig {
|
||||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
return &ACLLoggerConfig{
|
||||||
|
ConfigBase: ConfigBase{
|
||||||
|
BufferSize: DefaultBufferSize,
|
||||||
|
Retention: &Retention{Days: 30},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
|
||||||
|
utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
|
||||||
parsed, err := docker.ParseLabels(labels)
|
parsed, err := docker.ParseLabels(labels)
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
|
|
||||||
var config Config
|
var config RequestLoggerConfig
|
||||||
err = utils.Deserialize(parsed, &config)
|
err = utils.Deserialize(parsed, &config)
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
// Cookie header should be removed,
|
// Cookie header should be removed,
|
||||||
// stored in JSONLogEntry.Cookies instead.
|
// stored in JSONLogEntry.Cookies instead.
|
||||||
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
|
func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Headers.Default = FieldModeKeep
|
config.Fields.Headers.Default = FieldModeKeep
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
for k, v := range req.Header {
|
for k, v := range req.Header {
|
||||||
|
@ -29,7 +29,7 @@ func TestAccessLoggerJSONKeepHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
|
func TestAccessLoggerJSONDropHeaders(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Headers.Default = FieldModeDrop
|
config.Fields.Headers.Default = FieldModeDrop
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
for k := range req.Header {
|
for k := range req.Header {
|
||||||
|
@ -46,7 +46,7 @@ func TestAccessLoggerJSONDropHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
|
func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Headers.Default = FieldModeRedact
|
config.Fields.Headers.Default = FieldModeRedact
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
for k := range req.Header {
|
for k := range req.Header {
|
||||||
|
@ -57,7 +57,7 @@ func TestAccessLoggerJSONRedactHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
|
func TestAccessLoggerJSONKeepCookies(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Headers.Default = FieldModeKeep
|
config.Fields.Headers.Default = FieldModeKeep
|
||||||
config.Fields.Cookies.Default = FieldModeKeep
|
config.Fields.Cookies.Default = FieldModeKeep
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
|
@ -67,7 +67,7 @@ func TestAccessLoggerJSONKeepCookies(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
|
func TestAccessLoggerJSONRedactCookies(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Headers.Default = FieldModeKeep
|
config.Fields.Headers.Default = FieldModeKeep
|
||||||
config.Fields.Cookies.Default = FieldModeRedact
|
config.Fields.Cookies.Default = FieldModeRedact
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
|
@ -77,7 +77,7 @@ func TestAccessLoggerJSONRedactCookies(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONDropQuery(t *testing.T) {
|
func TestAccessLoggerJSONDropQuery(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Query.Default = FieldModeDrop
|
config.Fields.Query.Default = FieldModeDrop
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
expect.Equal(t, entry.Query["foo"], nil)
|
expect.Equal(t, entry.Query["foo"], nil)
|
||||||
|
@ -85,7 +85,7 @@ func TestAccessLoggerJSONDropQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
|
func TestAccessLoggerJSONRedactQuery(t *testing.T) {
|
||||||
config := DefaultConfig()
|
config := DefaultRequestLoggerConfig()
|
||||||
config.Fields.Query.Default = FieldModeRedact
|
config.Fields.Query.Default = FieldModeRedact
|
||||||
entry := getJSONEntry(t, config)
|
entry := getJSONEntry(t, config)
|
||||||
expect.Equal(t, entry.Query["foo"], []string{RedactedValue})
|
expect.Equal(t, entry.Query["foo"], []string{RedactedValue})
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
type File struct {
|
type File struct {
|
||||||
*os.File
|
*os.File
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
// os.File.Name() may not equal to key of `openedFiles`.
|
// os.File.Name() may not equal to key of `openedFiles`.
|
||||||
// Store it for later delete from `openedFiles`.
|
// Store it for later delete from `openedFiles`.
|
||||||
|
@ -26,18 +25,18 @@ var (
|
||||||
openedFilesMu sync.Mutex
|
openedFilesMu sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func newFileIO(path string) (AccessLogIO, error) {
|
func newFileIO(path string) (WriterWithName, error) {
|
||||||
openedFilesMu.Lock()
|
openedFilesMu.Lock()
|
||||||
|
defer openedFilesMu.Unlock()
|
||||||
|
|
||||||
var file *File
|
var file *File
|
||||||
path = pathPkg.Clean(path)
|
path = pathPkg.Clean(path)
|
||||||
if opened, ok := openedFiles[path]; ok {
|
if opened, ok := openedFiles[path]; ok {
|
||||||
opened.refCount.Add()
|
opened.refCount.Add()
|
||||||
file = opened
|
return opened, nil
|
||||||
} else {
|
} else {
|
||||||
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
|
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openedFilesMu.Unlock()
|
|
||||||
return nil, fmt.Errorf("access log open error: %w", err)
|
return nil, fmt.Errorf("access log open error: %w", err)
|
||||||
}
|
}
|
||||||
file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
|
file = &File{File: f, path: path, refCount: utils.NewRefCounter()}
|
||||||
|
@ -45,7 +44,6 @@ func newFileIO(path string) (AccessLogIO, error) {
|
||||||
go file.closeOnZero()
|
go file.closeOnZero()
|
||||||
}
|
}
|
||||||
|
|
||||||
openedFilesMu.Unlock()
|
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,11 @@ import (
|
||||||
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
|
func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultRequestLoggerConfig()
|
||||||
cfg.Path = "test.log"
|
cfg.Path = "test.log"
|
||||||
|
|
||||||
loggerCount := 10
|
loggerCount := 10
|
||||||
accessLogIOs := make([]AccessLogIO, loggerCount)
|
accessLogIOs := make([]WriterWithName, loggerCount)
|
||||||
|
|
||||||
// make test log file
|
// make test log file
|
||||||
file, err := os.Create(cfg.Path)
|
file, err := os.Create(cfg.Path)
|
||||||
|
@ -49,7 +49,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
|
||||||
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
|
func TestConcurrentAccessLoggerLogAndFlush(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
|
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultRequestLoggerConfig()
|
||||||
cfg.BufferSize = 1024
|
cfg.BufferSize = 1024
|
||||||
parent := task.RootTask("test", false)
|
parent := task.RootTask("test", false)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
acl "github.com/yusing/go-proxy/internal/acl/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +18,7 @@ type (
|
||||||
}
|
}
|
||||||
CombinedFormatter struct{ CommonFormatter }
|
CombinedFormatter struct{ CommonFormatter }
|
||||||
JSONFormatter struct{ CommonFormatter }
|
JSONFormatter struct{ CommonFormatter }
|
||||||
|
ACLLogFormatter struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
|
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||||
|
@ -56,7 +58,7 @@ func clientIP(req *http.Request) string {
|
||||||
return req.RemoteAddr
|
return req.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
|
func (f *CommonFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
|
||||||
query := f.cfg.Query.IterQuery(req.URL.Query())
|
query := f.cfg.Query.IterQuery(req.URL.Query())
|
||||||
|
|
||||||
line = append(line, req.Host...)
|
line = append(line, req.Host...)
|
||||||
|
@ -82,8 +84,8 @@ func (f *CommonFormatter) AppendLog(line []byte, req *http.Request, res *http.Re
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *CombinedFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
|
func (f *CombinedFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
|
||||||
line = f.CommonFormatter.AppendLog(line, req, res)
|
line = f.CommonFormatter.AppendRequestLog(line, req, res)
|
||||||
line = append(line, " \""...)
|
line = append(line, " \""...)
|
||||||
line = append(line, req.Referer()...)
|
line = append(line, req.Referer()...)
|
||||||
line = append(line, "\" \""...)
|
line = append(line, "\" \""...)
|
||||||
|
@ -118,14 +120,14 @@ func (z *zeroLogStringStringSliceMapMarshaler) MarshalZerologObject(e *zerolog.E
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Response) []byte {
|
func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *http.Response) []byte {
|
||||||
query := f.cfg.Query.ZerologQuery(req.URL.Query())
|
query := f.cfg.Query.ZerologQuery(req.URL.Query())
|
||||||
headers := f.cfg.Headers.ZerologHeaders(req.Header)
|
headers := f.cfg.Headers.ZerologHeaders(req.Header)
|
||||||
cookies := f.cfg.Cookies.ZerologCookies(req.Cookies())
|
cookies := f.cfg.Cookies.ZerologCookies(req.Cookies())
|
||||||
contentType := res.Header.Get("Content-Type")
|
contentType := res.Header.Get("Content-Type")
|
||||||
|
|
||||||
writer := bytes.NewBuffer(line)
|
writer := bytes.NewBuffer(line)
|
||||||
logger := zerolog.New(writer).With().Logger()
|
logger := zerolog.New(writer)
|
||||||
event := logger.Info().
|
event := logger.Info().
|
||||||
Str("time", utils.TimeNow().Format(LogTimeFormat)).
|
Str("time", utils.TimeNow().Format(LogTimeFormat)).
|
||||||
Str("ip", clientIP(req)).
|
Str("ip", clientIP(req)).
|
||||||
|
@ -155,3 +157,23 @@ func (f *JSONFormatter) AppendLog(line []byte, req *http.Request, res *http.Resp
|
||||||
event.Send()
|
event.Send()
|
||||||
return writer.Bytes()
|
return writer.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f ACLLogFormatter) AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte {
|
||||||
|
writer := bytes.NewBuffer(line)
|
||||||
|
logger := zerolog.New(writer)
|
||||||
|
event := logger.Info().
|
||||||
|
Str("time", utils.TimeNow().Format(LogTimeFormat)).
|
||||||
|
Str("ip", info.Str)
|
||||||
|
if blocked {
|
||||||
|
event.Str("action", "block")
|
||||||
|
} else {
|
||||||
|
event.Str("action", "allow")
|
||||||
|
}
|
||||||
|
if info.City != nil {
|
||||||
|
event.Str("iso_code", info.City.Country.IsoCode)
|
||||||
|
event.Str("time_zone", info.City.Location.TimeZone)
|
||||||
|
}
|
||||||
|
// NOTE: zerolog will append a newline to the buffer
|
||||||
|
event.Send()
|
||||||
|
return writer.Bytes()
|
||||||
|
}
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
package accesslog
|
package accesslog
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type MultiWriter struct {
|
type MultiWriter struct {
|
||||||
writers []AccessLogIO
|
writers []WriterWithName
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMultiWriter(writers ...AccessLogIO) AccessLogIO {
|
type MultiWriterInterface interface {
|
||||||
|
Unwrap() []io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMultiWriter(writers ...WriterWithName) WriterWithName {
|
||||||
if len(writers) == 0 {
|
if len(writers) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -18,6 +25,14 @@ func NewMultiWriter(writers ...AccessLogIO) AccessLogIO {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *MultiWriter) Unwrap() []io.Writer {
|
||||||
|
writers := make([]io.Writer, len(w.writers))
|
||||||
|
for i, writer := range w.writers {
|
||||||
|
writers[i] = writer
|
||||||
|
}
|
||||||
|
return writers
|
||||||
|
}
|
||||||
|
|
||||||
func (w *MultiWriter) Write(p []byte) (n int, err error) {
|
func (w *MultiWriter) Write(p []byte) (n int, err error) {
|
||||||
for _, writer := range w.writers {
|
for _, writer := range w.writers {
|
||||||
writer.Write(p)
|
writer.Write(p)
|
||||||
|
@ -25,18 +40,6 @@ func (w *MultiWriter) Write(p []byte) (n int, err error) {
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *MultiWriter) Lock() {
|
|
||||||
for _, writer := range w.writers {
|
|
||||||
writer.Lock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MultiWriter) Unlock() {
|
|
||||||
for _, writer := range w.writers {
|
|
||||||
writer.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MultiWriter) Name() string {
|
func (w *MultiWriter) Name() string {
|
||||||
names := make([]string, len(w.writers))
|
names := make([]string, len(w.writers))
|
||||||
for i, writer := range w.writers {
|
for i, writer := range w.writers {
|
||||||
|
|
|
@ -12,9 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type supportRotate interface {
|
type supportRotate interface {
|
||||||
io.Reader
|
io.ReadSeeker
|
||||||
io.Writer
|
|
||||||
io.Seeker
|
|
||||||
io.ReaderAt
|
io.ReaderAt
|
||||||
io.WriterAt
|
io.WriterAt
|
||||||
Truncate(size int64) error
|
Truncate(size int64) error
|
||||||
|
@ -41,6 +39,14 @@ func (r *RotateResult) Print(logger *zerolog.Logger) {
|
||||||
Msg("log rotate result")
|
Msg("log rotate result")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RotateResult) Add(other *RotateResult) {
|
||||||
|
r.NumBytesRead += other.NumBytesRead
|
||||||
|
r.NumBytesKeep += other.NumBytesKeep
|
||||||
|
r.NumLinesRead += other.NumLinesRead
|
||||||
|
r.NumLinesKeep += other.NumLinesKeep
|
||||||
|
r.NumLinesInvalid += other.NumLinesInvalid
|
||||||
|
}
|
||||||
|
|
||||||
type lineInfo struct {
|
type lineInfo struct {
|
||||||
Pos int64 // Position from the start of the file
|
Pos int64 // Position from the start of the file
|
||||||
Size int64 // Size of this line
|
Size int64 // Size of this line
|
||||||
|
|
|
@ -53,11 +53,11 @@ func TestParseLogTime(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRotateKeepLast(t *testing.T) {
|
func TestRotateKeepLast(t *testing.T) {
|
||||||
for _, format := range AvailableFormats {
|
for _, format := range ReqLoggerFormats {
|
||||||
t.Run(string(format)+" keep last", func(t *testing.T) {
|
t.Run(string(format)+" keep last", func(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
utils.MockTimeNow(testTime)
|
utils.MockTimeNow(testTime)
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: format,
|
Format: format,
|
||||||
})
|
})
|
||||||
expect.Nil(t, logger.Config().Retention)
|
expect.Nil(t, logger.Config().Retention)
|
||||||
|
@ -65,7 +65,7 @@ func TestRotateKeepLast(t *testing.T) {
|
||||||
for range 10 {
|
for range 10 {
|
||||||
logger.Log(req, resp)
|
logger.Log(req, resp)
|
||||||
}
|
}
|
||||||
expect.NoError(t, logger.Flush())
|
logger.Flush()
|
||||||
|
|
||||||
expect.Greater(t, file.Len(), int64(0))
|
expect.Greater(t, file.Len(), int64(0))
|
||||||
expect.Equal(t, file.NumLines(), 10)
|
expect.Equal(t, file.NumLines(), 10)
|
||||||
|
@ -85,7 +85,7 @@ func TestRotateKeepLast(t *testing.T) {
|
||||||
|
|
||||||
t.Run(string(format)+" keep days", func(t *testing.T) {
|
t.Run(string(format)+" keep days", func(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: format,
|
Format: format,
|
||||||
})
|
})
|
||||||
expect.Nil(t, logger.Config().Retention)
|
expect.Nil(t, logger.Config().Retention)
|
||||||
|
@ -127,10 +127,10 @@ func TestRotateKeepLast(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRotateKeepFileSize(t *testing.T) {
|
func TestRotateKeepFileSize(t *testing.T) {
|
||||||
for _, format := range AvailableFormats {
|
for _, format := range ReqLoggerFormats {
|
||||||
t.Run(string(format)+" keep size no rotation", func(t *testing.T) {
|
t.Run(string(format)+" keep size no rotation", func(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: format,
|
Format: format,
|
||||||
})
|
})
|
||||||
expect.Nil(t, logger.Config().Retention)
|
expect.Nil(t, logger.Config().Retention)
|
||||||
|
@ -160,7 +160,7 @@ func TestRotateKeepFileSize(t *testing.T) {
|
||||||
|
|
||||||
t.Run("keep size with rotation", func(t *testing.T) {
|
t.Run("keep size with rotation", func(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: FormatJSON,
|
Format: FormatJSON,
|
||||||
})
|
})
|
||||||
expect.Nil(t, logger.Config().Retention)
|
expect.Nil(t, logger.Config().Retention)
|
||||||
|
@ -189,10 +189,10 @@ func TestRotateKeepFileSize(t *testing.T) {
|
||||||
|
|
||||||
// skipping invalid lines is not supported for keep file_size
|
// skipping invalid lines is not supported for keep file_size
|
||||||
func TestRotateSkipInvalidTime(t *testing.T) {
|
func TestRotateSkipInvalidTime(t *testing.T) {
|
||||||
for _, format := range AvailableFormats {
|
for _, format := range ReqLoggerFormats {
|
||||||
t.Run(string(format), func(t *testing.T) {
|
t.Run(string(format), func(t *testing.T) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: format,
|
Format: format,
|
||||||
})
|
})
|
||||||
expect.Nil(t, logger.Config().Retention)
|
expect.Nil(t, logger.Config().Retention)
|
||||||
|
@ -232,9 +232,11 @@ func BenchmarkRotate(b *testing.B) {
|
||||||
for _, retention := range tests {
|
for _, retention := range tests {
|
||||||
b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
|
b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: FormatJSON,
|
ConfigBase: ConfigBase{
|
||||||
Retention: retention,
|
Retention: retention,
|
||||||
|
},
|
||||||
|
Format: FormatJSON,
|
||||||
})
|
})
|
||||||
for i := range 100 {
|
for i := range 100 {
|
||||||
utils.MockTimeNow(testTime.AddDate(0, 0, -100+i+1))
|
utils.MockTimeNow(testTime.AddDate(0, 0, -100+i+1))
|
||||||
|
@ -263,9 +265,11 @@ func BenchmarkRotateWithInvalidTime(b *testing.B) {
|
||||||
for _, retention := range tests {
|
for _, retention := range tests {
|
||||||
b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
|
b.Run(fmt.Sprintf("retention_%s", retention), func(b *testing.B) {
|
||||||
file := NewMockFile()
|
file := NewMockFile()
|
||||||
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &Config{
|
logger := NewAccessLoggerWithIO(task.RootTask("test", false), file, &RequestLoggerConfig{
|
||||||
Format: FormatJSON,
|
ConfigBase: ConfigBase{
|
||||||
Retention: retention,
|
Retention: retention,
|
||||||
|
},
|
||||||
|
Format: FormatJSON,
|
||||||
})
|
})
|
||||||
for i := range 10000 {
|
for i := range 10000 {
|
||||||
utils.MockTimeNow(testTime.AddDate(0, 0, -10000+i+1))
|
utils.MockTimeNow(testTime.AddDate(0, 0, -10000+i+1))
|
||||||
|
|
|
@ -11,8 +11,6 @@ type StdoutLogger struct {
|
||||||
|
|
||||||
var stdoutIO = &StdoutLogger{os.Stdout}
|
var stdoutIO = &StdoutLogger{os.Stdout}
|
||||||
|
|
||||||
func (l *StdoutLogger) Lock() {}
|
|
||||||
func (l *StdoutLogger) Unlock() {}
|
|
||||||
func (l *StdoutLogger) Name() string {
|
func (l *StdoutLogger) Name() string {
|
||||||
return "stdout"
|
return "stdout"
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/http3"
|
"github.com/quic-go/quic-go/http3"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/yusing/go-proxy/internal/acl"
|
||||||
"github.com/yusing/go-proxy/internal/autocert"
|
"github.com/yusing/go-proxy/internal/autocert"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
@ -21,6 +22,7 @@ type Server struct {
|
||||||
http *http.Server
|
http *http.Server
|
||||||
https *http.Server
|
https *http.Server
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
acl *acl.Config
|
||||||
|
|
||||||
l zerolog.Logger
|
l zerolog.Logger
|
||||||
}
|
}
|
||||||
|
@ -31,6 +33,7 @@ type Options struct {
|
||||||
HTTPSAddr string
|
HTTPSAddr string
|
||||||
CertProvider *autocert.Provider
|
CertProvider *autocert.Provider
|
||||||
Handler http.Handler
|
Handler http.Handler
|
||||||
|
ACL *acl.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpServer interface {
|
type httpServer interface {
|
||||||
|
@ -76,6 +79,7 @@ func NewServer(opt Options) (s *Server) {
|
||||||
http: httpSer,
|
http: httpSer,
|
||||||
https: httpsSer,
|
https: httpsSer,
|
||||||
l: logger,
|
l: logger,
|
||||||
|
acl: opt.ACL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,16 +99,16 @@ func (s *Server) Start(parent task.Parent) {
|
||||||
Handler: s.https.Handler,
|
Handler: s.https.Handler,
|
||||||
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
|
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
|
||||||
}
|
}
|
||||||
Start(subtask, h3, &s.l)
|
Start(subtask, h3, s.acl, &s.l)
|
||||||
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
|
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
|
||||||
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
|
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
|
||||||
}
|
}
|
||||||
|
|
||||||
Start(subtask, s.http, &s.l)
|
Start(subtask, s.http, s.acl, &s.l)
|
||||||
Start(subtask, s.https, &s.l)
|
Start(subtask, s.https, s.acl, &s.l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Logger) {
|
func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) {
|
||||||
if srv == nil {
|
if srv == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -130,6 +134,9 @@ func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Lo
|
||||||
if srv.TLSConfig != nil {
|
if srv.TLSConfig != nil {
|
||||||
l = tls.NewListener(l, srv.TLSConfig)
|
l = tls.NewListener(l, srv.TLSConfig)
|
||||||
}
|
}
|
||||||
|
if acl != nil {
|
||||||
|
l = acl.WrapTCP(l)
|
||||||
|
}
|
||||||
serveFunc = getServeFunc(l, srv.Serve)
|
serveFunc = getServeFunc(l, srv.Serve)
|
||||||
case *http3.Server:
|
case *http3.Server:
|
||||||
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
|
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
|
||||||
|
@ -137,6 +144,9 @@ func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Lo
|
||||||
HandleError(logger, err, "failed to listen on port")
|
HandleError(logger, err, "failed to listen on port")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if acl != nil {
|
||||||
|
l = acl.WrapUDP(l)
|
||||||
|
}
|
||||||
serveFunc = getServeFunc(l, srv.Serve)
|
serveFunc = getServeFunc(l, srv.Serve)
|
||||||
}
|
}
|
||||||
task.OnCancel("stop", func() {
|
task.OnCancel("stop", func() {
|
||||||
|
|
|
@ -48,7 +48,7 @@ type (
|
||||||
LoadBalance *loadbalance.Config `json:"load_balance,omitempty"`
|
LoadBalance *loadbalance.Config `json:"load_balance,omitempty"`
|
||||||
Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"`
|
Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"`
|
||||||
Homepage *homepage.ItemConfig `json:"homepage,omitempty"`
|
Homepage *homepage.ItemConfig `json:"homepage,omitempty"`
|
||||||
AccessLog *accesslog.Config `json:"access_log,omitempty"`
|
AccessLog *accesslog.RequestLoggerConfig `json:"access_log,omitempty"`
|
||||||
|
|
||||||
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
|
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue