breaking: move maxmind config to config.providers

- moved maxmind to separate module
- code refactored
- simplified test
This commit is contained in:
yusing 2025-05-03 20:58:09 +08:00
parent ac1470d81d
commit ccb4639f43
14 changed files with 314 additions and 369 deletions

View file

@ -2,17 +2,13 @@ package acl
import ( import (
"net" "net"
"sync"
"time" "time"
"github.com/oschwald/maxminddb-golang"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/accesslog" "github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -20,43 +16,23 @@ import (
type Config struct { type Config struct {
Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow
AllowLocal *bool `json:"allow_local"` // default: true AllowLocal *bool `json:"allow_local"` // default: true
Allow []string `json:"allow"` Allow Matchers `json:"allow"`
Deny []string `json:"deny"` Deny Matchers `json:"deny"`
Log *accesslog.ACLLoggerConfig `json:"log"` Log *accesslog.ACLLoggerConfig `json:"log"`
MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"`
config config
} }
type (
MaxMindDatabaseType string
MaxMindConfig struct {
AccountID string `json:"account_id" validate:"required"`
LicenseKey string `json:"license_key" validate:"required"`
Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"`
logger zerolog.Logger
lastUpdate time.Time
db struct {
*maxminddb.Reader
sync.RWMutex
}
}
)
type config struct { type config struct {
defaultAllow bool defaultAllow bool
allowLocal bool allowLocal bool
allow []matcher
deny []matcher
ipCache *xsync.MapOf[string, *checkCache] ipCache *xsync.MapOf[string, *checkCache]
logAllowed bool logAllowed bool
logger *accesslog.AccessLogger logger *accesslog.AccessLogger
} }
type checkCache struct { type checkCache struct {
*acl.IPInfo *maxmind.IPInfo
allow bool allow bool
created time.Time created time.Time
} }
@ -74,11 +50,6 @@ const (
ACLDeny = "deny" ACLDeny = "deny"
) )
const (
MaxMindGeoLite MaxMindDatabaseType = "geolite"
MaxMindGeoIP2 MaxMindDatabaseType = "geoip2"
)
func (c *Config) Validate() gperr.Error { func (c *Config) Validate() gperr.Error {
switch c.Default { switch c.Default {
case "", ACLAllow: case "", ACLAllow:
@ -95,55 +66,19 @@ func (c *Config) Validate() gperr.Error {
c.allowLocal = true c.allowLocal = true
} }
if c.MaxMind != nil {
c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger()
}
if c.Log != nil { if c.Log != nil {
c.logAllowed = c.Log.LogAllowed c.logAllowed = c.Log.LogAllowed
} }
errs := gperr.NewBuilder("syntax error")
c.allow = make([]matcher, 0, len(c.Allow))
c.deny = make([]matcher, 0, len(c.Deny))
for _, s := range c.Allow {
m, err := c.parseMatcher(s)
if err != nil {
errs.Add(err.Subject(s))
continue
}
c.allow = append(c.allow, m)
}
for _, s := range c.Deny {
m, err := c.parseMatcher(s)
if err != nil {
errs.Add(err.Subject(s))
continue
}
c.deny = append(c.deny, m)
}
if errs.HasError() {
c.allow = nil
c.deny = nil
return errMatcherFormat.With(errs.Error())
}
c.ipCache = xsync.NewMapOf[string, *checkCache]() c.ipCache = xsync.NewMapOf[string, *checkCache]()
return nil return nil
} }
func (c *Config) Valid() bool { func (c *Config) Valid() bool {
return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal) return c != nil && (len(c.Allow) > 0 || len(c.Deny) > 0 || c.allowLocal)
} }
func (c *Config) Start(parent *task.Task) gperr.Error { func (c *Config) Start(parent *task.Task) gperr.Error {
if c.MaxMind != nil {
if err := c.MaxMind.LoadMaxMindDB(parent); err != nil {
return err
}
}
if c.Log != nil { if c.Log != nil {
logger, err := accesslog.NewAccessLogger(parent, c.Log) logger, err := accesslog.NewAccessLogger(parent, c.Log)
if err != nil { if err != nil {
@ -154,9 +89,9 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
return nil return nil
} }
func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) { func (c *Config) cacheRecord(info *maxmind.IPInfo, allow bool) {
if common.ForceResolveCountry && info.City == nil { if common.ForceResolveCountry && info.City == nil {
c.MaxMind.lookupCity(info) maxmind.LookupCity(info)
} }
c.ipCache.Store(info.Str, &checkCache{ c.ipCache.Store(info.Str, &checkCache{
IPInfo: info, IPInfo: info,
@ -165,7 +100,7 @@ func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) {
}) })
} }
func (c *config) log(info *acl.IPInfo, allowed bool) { func (c *config) log(info *maxmind.IPInfo, allowed bool) {
if c.logger == nil { if c.logger == nil {
return return
} }
@ -186,7 +121,7 @@ func (c *Config) IPAllowed(ip net.IP) bool {
} }
if c.allowLocal && ip.IsPrivate() { if c.allowLocal && ip.IsPrivate() {
c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true) c.log(&maxmind.IPInfo{IP: ip, Str: ip.String()}, true)
return true return true
} }
@ -197,15 +132,15 @@ func (c *Config) IPAllowed(ip net.IP) bool {
return record.allow return record.allow
} }
ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr} ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr}
for _, m := range c.allow { for _, m := range c.Allow {
if m(ipAndStr) { if m(ipAndStr) {
c.log(ipAndStr, true) c.log(ipAndStr, true)
c.cacheRecord(ipAndStr, true) c.cacheRecord(ipAndStr, true)
return true return true
} }
} }
for _, m := range c.deny { for _, m := range c.Deny {
if m(ipAndStr) { if m(ipAndStr) {
c.log(ipAndStr, false) c.log(ipAndStr, false)
c.cacheRecord(ipAndStr, false) c.cacheRecord(ipAndStr, false)

View file

@ -4,11 +4,12 @@ import (
"net" "net"
"strings" "strings"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/maxmind"
) )
type matcher func(*acl.IPInfo) bool type Matcher func(*maxmind.IPInfo) bool
type Matchers []Matcher
const ( const (
MatcherTypeIP = "ip" MatcherTypeIP = "ip"
@ -32,7 +33,7 @@ var (
errMaxMindNotConfigured = gperr.New("MaxMind not configured") errMaxMindNotConfigured = gperr.New("MaxMind not configured")
) )
func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) { func ParseMatcher(s string) (Matcher, gperr.Error) {
parts := strings.Split(s, ":") parts := strings.Split(s, ":")
if len(parts) != 2 { if len(parts) != 2 {
return nil, errSyntax return nil, errSyntax
@ -52,35 +53,44 @@ func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
} }
return matchCIDR(net), nil return matchCIDR(net), nil
case MatcherTypeTimeZone: case MatcherTypeTimeZone:
if cfg.MaxMind == nil { if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured return nil, errMaxMindNotConfigured
} }
return cfg.MaxMind.matchTimeZone(parts[1]), nil return matchTimeZone(parts[1]), nil
case MatcherTypeCountry: case MatcherTypeCountry:
if cfg.MaxMind == nil { if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured return nil, errMaxMindNotConfigured
} }
return cfg.MaxMind.matchISOCode(parts[1]), nil return matchISOCode(parts[1]), nil
default: default:
return nil, errSyntax return nil, errSyntax
} }
} }
func matchIP(ip net.IP) matcher { func (matchers Matchers) Match(ip *maxmind.IPInfo) bool {
return func(ip2 *acl.IPInfo) bool { for _, m := range matchers {
if m(ip) {
return true
}
}
return false
}
func matchIP(ip net.IP) Matcher {
return func(ip2 *maxmind.IPInfo) bool {
return ip.Equal(ip2.IP) return ip.Equal(ip2.IP)
} }
} }
func matchCIDR(n *net.IPNet) matcher { func matchCIDR(n *net.IPNet) Matcher {
return func(ip *acl.IPInfo) bool { return func(ip *maxmind.IPInfo) bool {
return n.Contains(ip.IP) return n.Contains(ip.IP)
} }
} }
func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher { func matchTimeZone(tz string) Matcher {
return func(ip *acl.IPInfo) bool { return func(ip *maxmind.IPInfo) bool {
city, ok := cfg.lookupCity(ip) city, ok := maxmind.LookupCity(ip)
if !ok { if !ok {
return false return false
} }
@ -88,9 +98,9 @@ func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
} }
} }
func (cfg *MaxMindConfig) matchISOCode(iso string) matcher { func matchISOCode(iso string) Matcher {
return func(ip *acl.IPInfo) bool { return func(ip *maxmind.IPInfo) bool {
city, ok := cfg.lookupCity(ip) city, ok := maxmind.LookupCity(ip)
if !ok { if !ok {
return false return false
} }

View file

@ -1,223 +0,0 @@
package acl
import (
"archive/tar"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/task"
)
func Test_dbPath(t *testing.T) {
tmpDataDir := "/tmp/testdata"
oldDataDir := dataDir
dataDir = tmpDataDir
defer func() { dataDir = oldDataDir }()
tests := []struct {
name string
dbType MaxMindDatabaseType
want string
}{
{"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")},
{"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := dbPath(tt.dbType); got != tt.want {
t.Errorf("dbPath() = %v, want %v", got, tt.want)
}
})
}
}
func Test_dbURL(t *testing.T) {
tests := []struct {
name string
dbType MaxMindDatabaseType
want string
}{
{"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"},
{"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := dbURL(tt.dbType); got != tt.want {
t.Errorf("dbURL() = %v, want %v", got, tt.want)
}
})
}
}
// --- Helper for MaxMindConfig ---
type testLogger struct{ zerolog.Logger }
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
func Test_MaxMindConfig_newReq(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "testid",
LicenseKey: "testkey",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
// Patch httpClient to use httptest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
t.Errorf("basic auth not set correctly")
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
resp, err := cfg.newReq(http.MethodGet)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status: %v", resp.StatusCode)
}
}
func Test_MaxMindConfig_checkUpdate(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
lastMod := time.Now().UTC().Format(http.TimeFormat)
buildTime := time.Now().Add(-time.Hour)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Last-Modified", lastMod)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
latest, err := cfg.checkLastest()
if err != nil {
t.Fatalf("checkUpdate() error = %v", err)
}
if latest.Equal(buildTime) {
t.Errorf("expected update needed")
}
}
type fakeReadCloser struct {
firstRead bool
closed bool
}
func (c *fakeReadCloser) Read(p []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
return strings.NewReader("FAKEMMDB").Read(p)
}
return 0, io.EOF
}
func (c *fakeReadCloser) Close() error {
c.closed = true
return nil
}
func Test_MaxMindConfig_download(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gz := gzip.NewWriter(w)
t := tar.NewWriter(gz)
t.WriteHeader(&tar.Header{
Name: dbFilename(MaxMindGeoLite),
})
t.Write([]byte("1234"))
t.Close()
gz.Close()
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
tmpDir := t.TempDir()
oldDataDir := dataDir
dataDir = tmpDir
defer func() { dataDir = oldDataDir }()
// Patch maxminddb.Open to always succeed
origOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
defer func() { maxmindDBOpen = origOpen }()
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
rw := httptest.NewRecorder()
oldNewReq := newReq
newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) {
server.Config.Handler.ServeHTTP(rw, req)
return rw.Result(), nil
}
defer func() { newReq = oldNewReq }()
err = cfg.download()
if err != nil {
t.Fatalf("download() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
// This test should cover both the path where DB exists and where it does not
// For brevity, only the non-existing path is tested here
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
oldOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
defer func() { maxmindDBOpen = oldOpen }()
oldDBPath := dbPath
dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") }
defer func() { dbPath = oldDBPath }()
task := task.RootTask("test")
defer task.Finish(nil)
err := cfg.LoadMaxMindDB(task)
if err != nil {
t.Errorf("loadMaxMindDB() error = %v", err)
}
}

View file

@ -17,6 +17,7 @@ import (
"github.com/yusing/go-proxy/internal/entrypoint" "github.com/yusing/go-proxy/internal/entrypoint"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox" "github.com/yusing/go-proxy/internal/proxmox"
@ -230,6 +231,7 @@ func (cfg *Config) load() gperr.Error {
errs := gperr.NewBuilder(errMsg) errs := gperr.NewBuilder(errMsg)
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
errs.Add(cfg.initMaxMind(model.Providers.MaxMind))
cfg.initNotification(model.Providers.Notification) cfg.initNotification(model.Providers.Notification)
errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.initProxmox(model.Providers.Proxmox)) errs.Add(cfg.initProxmox(model.Providers.Proxmox))
@ -262,6 +264,13 @@ func (cfg *Config) load() gperr.Error {
return nil return nil
} }
func (cfg *Config) initMaxMind(maxmindCfg *maxmind.Config) gperr.Error {
if maxmindCfg != nil {
return maxmind.SetInstance(cfg.task, maxmindCfg)
}
return nil
}
func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) { func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
if len(notifCfg) == 0 { if len(notifCfg) == 0 {
return return

View file

@ -11,6 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging/accesslog" "github.com/yusing/go-proxy/internal/logging/accesslog"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox" "github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
@ -32,6 +33,7 @@ type (
Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"`
Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"`
Proxmox []proxmox.Config `json:"proxmox" yaml:"proxmox,omitempty"` Proxmox []proxmox.Config `json:"proxmox" yaml:"proxmox,omitempty"`
MaxMind *maxmind.Config `json:"maxmind" yaml:"maxmind,omitempty"`
} }
Entrypoint struct { Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"` Middlewares []map[string]any `json:"middlewares"`

View file

@ -9,9 +9,9 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/utils/synk" "github.com/yusing/go-proxy/internal/utils/synk"
@ -61,7 +61,7 @@ type (
} }
ACLFormatter interface { ACLFormatter interface {
// AppendACLLog appends a log line to line with or without a trailing newline // AppendACLLog appends a log line to line with or without a trailing newline
AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte
} }
) )
@ -179,7 +179,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
} }
func (l *AccessLogger) LogACL(info *acl.IPInfo, blocked bool) { func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := l.lineBufPool.Get() line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line) defer l.lineBufPool.Put(line)
line = l.ACLFormatter.AppendACLLog(line, info, blocked) line = l.ACLFormatter.AppendACLLog(line, info, blocked)

View file

@ -8,7 +8,7 @@ import (
"strconv" "strconv"
"github.com/rs/zerolog" "github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types" maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -158,7 +158,7 @@ func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *ht
return writer.Bytes() return writer.Bytes()
} }
func (f ACLLogFormatter) AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte { func (f ACLLogFormatter) AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte {
writer := bytes.NewBuffer(line) writer := bytes.NewBuffer(line)
logger := zerolog.New(writer) logger := zerolog.New(writer)
event := logger.Info(). event := logger.Info().

View file

@ -1,13 +1,12 @@
package acl package maxmind
import ( import (
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
acl "github.com/yusing/go-proxy/internal/acl/types"
) )
var cityCache = xsync.NewMapOf[string, *acl.City]() var cityCache = xsync.NewMapOf[string, *City]()
func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) { func (cfg *MaxMind) lookupCity(ip *IPInfo) (*City, bool) {
if ip.City != nil { if ip.City != nil {
return ip.City, true return ip.City, true
} }
@ -25,7 +24,7 @@ func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) {
cfg.db.RLock() cfg.db.RLock()
defer cfg.db.RUnlock() defer cfg.db.RUnlock()
city = new(acl.City) city = new(City)
err := cfg.db.Lookup(ip.IP, city) err := cfg.db.Lookup(ip.IP, city)
if err != nil { if err != nil {
return nil, false return nil, false

View file

@ -0,0 +1,31 @@
package maxmind
import (
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/task"
)
var instance *MaxMind
func SetInstance(parent task.Parent, cfg *Config) gperr.Error {
newInstance := &MaxMind{Config: cfg}
if err := newInstance.LoadMaxMindDB(parent); err != nil {
return err
}
if instance != nil {
instance.task.Finish("updated")
}
instance = newInstance
return nil
}
func HasInstance() bool {
return instance != nil
}
func LookupCity(ip *IPInfo) (*City, bool) {
if instance == nil {
return nil, false
}
return instance.lookupCity(ip)
}

View file

@ -1,4 +1,4 @@
package acl package maxmind
import ( import (
"archive/tar" "archive/tar"
@ -9,14 +9,32 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"time" "time"
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
) )
type MaxMind struct {
*Config
lastUpdate time.Time
task *task.Task
db struct {
*maxminddb.Reader
sync.RWMutex
}
}
type (
Config = maxmind.Config
IPInfo = maxmind.IPInfo
City = maxmind.City
)
var ( var (
updateInterval = 24 * time.Hour updateInterval = 24 * time.Hour
httpClient = &http.Client{ httpClient = &http.Client{
@ -26,33 +44,34 @@ var (
ErrDownloadFailure = gperr.New("download failure") ErrDownloadFailure = gperr.New("download failure")
) )
func dbPathImpl(dbType MaxMindDatabaseType) string { func (cfg *MaxMind) dbPath() string {
if dbType == MaxMindGeoLite { if cfg.Database == maxmind.MaxMindGeoLite {
return filepath.Join(dataDir, "GeoLite2-City.mmdb") return filepath.Join(dataDir, "GeoLite2-City.mmdb")
} }
return filepath.Join(dataDir, "GeoIP2-City.mmdb") return filepath.Join(dataDir, "GeoIP2-City.mmdb")
} }
func dbURLimpl(dbType MaxMindDatabaseType) string { func (cfg *MaxMind) dbURL() string {
if dbType == MaxMindGeoLite { if cfg.Database == maxmind.MaxMindGeoLite {
return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz" return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"
} }
return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz" return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"
} }
func dbFilename(dbType MaxMindDatabaseType) string { func (cfg *MaxMind) dbFilename() string {
if dbType == MaxMindGeoLite { if cfg.Database == maxmind.MaxMindGeoLite {
return "GeoLite2-City.mmdb" return "GeoLite2-City.mmdb"
} }
return "GeoIP2-City.mmdb" return "GeoIP2-City.mmdb"
} }
func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error { func (cfg *MaxMind) LoadMaxMindDB(parent task.Parent) gperr.Error {
if cfg.Database == "" { if cfg.Database == "" {
return nil return nil
} }
path := dbPath(cfg.Database) cfg.task = parent.Subtask("maxmind_db", true)
path := dbPath(cfg)
reader, err := maxmindDBOpen(path) reader, err := maxmindDBOpen(path)
valid := true valid := true
if err != nil { if err != nil {
@ -69,32 +88,32 @@ func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error {
} }
if !valid { if !valid {
cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...") cfg.Logger().Info().Msg("MaxMind DB not found/invalid, downloading...")
if err = cfg.download(); err != nil { if err = cfg.download(); err != nil {
return ErrDownloadFailure.With(err) return ErrDownloadFailure.With(err)
} }
} else { } else {
cfg.logger.Info().Msg("MaxMind DB loaded") cfg.Logger().Info().Msg("MaxMind DB loaded")
cfg.db.Reader = reader cfg.db.Reader = reader
go cfg.scheduleUpdate(parent) go cfg.scheduleUpdate(cfg.task)
} }
return nil return nil
} }
func (cfg *MaxMindConfig) loadLastUpdate() { func (cfg *MaxMind) loadLastUpdate() {
f, err := os.Stat(dbPath(cfg.Database)) f, err := os.Stat(cfg.dbPath())
if err != nil { if err != nil {
return return
} }
cfg.lastUpdate = f.ModTime() cfg.lastUpdate = f.ModTime()
} }
func (cfg *MaxMindConfig) setLastUpdate(t time.Time) { func (cfg *MaxMind) setLastUpdate(t time.Time) {
cfg.lastUpdate = t cfg.lastUpdate = t
_ = os.Chtimes(dbPath(cfg.Database), t, t) _ = os.Chtimes(cfg.dbPath(), t, t)
} }
func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) { func (cfg *MaxMind) scheduleUpdate(parent task.Parent) {
task := parent.Subtask("schedule_update", true) task := parent.Subtask("schedule_update", true)
ticker := time.NewTicker(updateInterval) ticker := time.NewTicker(updateInterval)
@ -119,45 +138,45 @@ func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) {
} }
} }
func (cfg *MaxMindConfig) update() { func (cfg *MaxMind) update() {
// check for update // check for update
cfg.logger.Info().Msg("checking for MaxMind DB update...") cfg.Logger().Info().Msg("checking for MaxMind DB update...")
remoteLastModified, err := cfg.checkLastest() remoteLastModified, err := cfg.checkLastest()
if err != nil { if err != nil {
cfg.logger.Err(err).Msg("failed to check MaxMind DB update") cfg.Logger().Err(err).Msg("failed to check MaxMind DB update")
return return
} }
if remoteLastModified.Equal(cfg.lastUpdate) { if remoteLastModified.Equal(cfg.lastUpdate) {
cfg.logger.Info().Msg("MaxMind DB is up to date") cfg.Logger().Info().Msg("MaxMind DB is up to date")
return return
} }
cfg.logger.Info(). cfg.Logger().Info().
Time("latest", remoteLastModified.Local()). Time("latest", remoteLastModified.Local()).
Time("current", cfg.lastUpdate). Time("current", cfg.lastUpdate).
Msg("MaxMind DB update available") Msg("MaxMind DB update available")
if err = cfg.download(); err != nil { if err = cfg.download(); err != nil {
cfg.logger.Err(err).Msg("failed to update MaxMind DB") cfg.Logger().Err(err).Msg("failed to update MaxMind DB")
return return
} }
cfg.logger.Info().Msg("MaxMind DB updated") cfg.Logger().Info().Msg("MaxMind DB updated")
} }
func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) { func (cfg *MaxMind) doReq(method string) (*http.Response, error) {
req, err := http.NewRequest(method, dbURL(cfg.Database), nil) req, err := http.NewRequest(method, cfg.dbURL(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey) req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey)
resp, err := httpClient.Do(req) resp, err := doReq(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil
} }
func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) { func (cfg *MaxMind) checkLastest() (lastModifiedT *time.Time, err error) {
resp, err := newReq(cfg, http.MethodHead) resp, err := cfg.doReq(http.MethodHead)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -169,21 +188,21 @@ func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) {
lastModified := resp.Header.Get("Last-Modified") lastModified := resp.Header.Get("Last-Modified")
if lastModified == "" { if lastModified == "" {
cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped") cfg.Logger().Warn().Msg("MaxMind responded no last modified time, update skipped")
return nil, nil return nil, nil
} }
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified) lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
if err != nil { if err != nil {
cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped") cfg.Logger().Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped")
return nil, err return nil, err
} }
return &lastModifiedTime, nil return &lastModifiedTime, nil
} }
func (cfg *MaxMindConfig) download() error { func (cfg *MaxMind) download() error {
resp, err := newReq(cfg, http.MethodGet) resp, err := cfg.doReq(http.MethodGet)
if err != nil { if err != nil {
return err return err
} }
@ -193,7 +212,7 @@ func (cfg *MaxMindConfig) download() error {
return fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode) return fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
} }
dbFile := dbPath(cfg.Database) dbFile := dbPath(cfg)
tmpGZPath := dbFile + "-tmp.tar.gz" tmpGZPath := dbFile + "-tmp.tar.gz"
tmpDBPath := dbFile + "-tmp" tmpDBPath := dbFile + "-tmp"
@ -208,7 +227,7 @@ func (cfg *MaxMindConfig) download() error {
_ = os.Remove(tmpGZPath) _ = os.Remove(tmpGZPath)
}() }()
cfg.logger.Info().Msg("MaxMind DB downloading...") cfg.Logger().Info().Msg("MaxMind DB downloading...")
_, err = io.Copy(tmpGZFile, resp.Body) _, err = io.Copy(tmpGZFile, resp.Body)
if err != nil { if err != nil {
@ -220,7 +239,7 @@ func (cfg *MaxMindConfig) download() error {
} }
// extract .tar.gz and to database // extract .tar.gz and to database
err = extractFileFromTarGz(tmpGZFile, dbFilename(cfg.Database), tmpDBPath) err = extractFileFromTarGz(tmpGZFile, cfg.dbFilename(), tmpDBPath)
if err != nil { if err != nil {
return gperr.New("failed to extract database from archive").With(err) return gperr.New("failed to extract database from archive").With(err)
@ -255,7 +274,7 @@ func (cfg *MaxMindConfig) download() error {
cfg.setLastUpdate(lastModifiedTime) cfg.setLastUpdate(lastModifiedTime)
} }
cfg.logger.Info().Msg("MaxMind DB downloaded") cfg.Logger().Info().Msg("MaxMind DB downloaded")
return nil return nil
} }
@ -296,8 +315,7 @@ func extractFileFromTarGz(tarGzFile *os.File, targetFilename, destPath string) e
var ( var (
dataDir = common.DataDir dataDir = common.DataDir
dbURL = dbURLimpl dbPath = (*MaxMind).dbPath
dbPath = dbPathImpl doReq = httpClient.Do
maxmindDBOpen = maxminddb.Open maxmindDBOpen = maxminddb.Open
newReq = (*MaxMindConfig).newReq
) )

View file

@ -0,0 +1,131 @@
package maxmind
import (
"archive/tar"
"compress/gzip"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/rs/zerolog"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
)
// --- Helper for MaxMindConfig ---
type testLogger struct{ zerolog.Logger }
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
func testCfg() *MaxMind {
return &MaxMind{
Config: &Config{
AccountID: "testid",
LicenseKey: "testkey",
Database: maxmind.MaxMindGeoLite,
},
}
}
var testLastMod = time.Now().UTC()
func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) {
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
gz := gzip.NewWriter(w)
t := tar.NewWriter(gz)
t.WriteHeader(&tar.Header{
Name: cfg.dbFilename(),
})
t.Write([]byte("1234"))
t.Close()
gz.Close()
w.WriteHeader(http.StatusOK)
}
func mockDoReq(cfg *MaxMind, t *testing.T) {
rw := httptest.NewRecorder()
oldDoReq := doReq
doReq = func(req *http.Request) (*http.Response, error) {
testDoReq(cfg, rw, req)
return rw.Result(), nil
}
t.Cleanup(func() { doReq = oldDoReq })
}
func mockDataDir(t *testing.T) {
oldDataDir := dataDir
dataDir = t.TempDir()
t.Cleanup(func() { dataDir = oldDataDir })
}
func mockMaxMindDBOpen(t *testing.T) {
oldMaxMindDBOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
t.Cleanup(func() { maxmindDBOpen = oldMaxMindDBOpen })
}
func Test_MaxMindConfig_doReq(t *testing.T) {
cfg := testCfg()
mockDoReq(cfg, t)
resp, err := cfg.doReq(http.MethodGet)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status: %v", resp.StatusCode)
}
}
func Test_MaxMindConfig_checkLatest(t *testing.T) {
cfg := testCfg()
mockDoReq(cfg, t)
latest, err := cfg.checkLastest()
if err != nil {
t.Fatalf("checkLatest() error = %v", err)
}
if latest.Equal(testLastMod) {
t.Errorf("expected latest equal to testLastMod")
}
}
func Test_MaxMindConfig_download(t *testing.T) {
cfg := testCfg()
mockDataDir(t)
mockMaxMindDBOpen(t)
mockDoReq(cfg, t)
err := cfg.download()
if err != nil {
t.Fatalf("download() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
cfg := testCfg()
mockDataDir(t)
mockMaxMindDBOpen(t)
task := task.RootTask("test")
defer task.Finish(nil)
err := cfg.LoadMaxMindDB(task)
if err != nil {
t.Errorf("loadMaxMindDB() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}

View file

@ -1,4 +1,4 @@
package acl package maxmind
type City struct { type City struct {
Location struct { Location struct {

View file

@ -0,0 +1,33 @@
package maxmind
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
)
type (
DatabaseType string
Config struct {
AccountID string `json:"account_id" validate:"required"`
LicenseKey string `json:"license_key" validate:"required"`
Database DatabaseType `json:"database" validate:"omitempty,oneof=geolite geoip2"`
}
)
const (
MaxMindGeoLite DatabaseType = "geolite"
MaxMindGeoIP2 DatabaseType = "geoip2"
)
func (cfg *Config) Validate() gperr.Error {
if cfg.Database == "" {
cfg.Database = MaxMindGeoLite
}
return nil
}
func (cfg *Config) Logger() *zerolog.Logger {
l := logging.With().Str("database", string(cfg.Database)).Logger()
return &l
}

View file

@ -1,4 +1,4 @@
package acl package maxmind
import "net" import "net"