From ccb4639f4320362a2037e3bfe6ffb54b28b78e11 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 3 May 2025 20:58:09 +0800 Subject: [PATCH] breaking: move maxmind config to config.providers - moved maxmind to separate module - code refactored - simplified test --- internal/acl/config.go | 89 +------- internal/acl/matcher.go | 44 ++-- internal/acl/maxmind_test.go | 223 ------------------- internal/config/config.go | 9 + internal/config/types/config.go | 2 + internal/logging/accesslog/access_logger.go | 6 +- internal/logging/accesslog/formatter.go | 4 +- internal/{acl => maxmind}/city_cache.go | 9 +- internal/maxmind/instance.go | 31 +++ internal/{acl => maxmind}/maxmind.go | 98 ++++---- internal/maxmind/maxmind_test.go | 131 +++++++++++ internal/{acl => maxmind}/types/city_info.go | 2 +- internal/maxmind/types/config.go | 33 +++ internal/{acl => maxmind}/types/ip_info.go | 2 +- 14 files changed, 314 insertions(+), 369 deletions(-) delete mode 100644 internal/acl/maxmind_test.go rename internal/{acl => maxmind}/city_cache.go (66%) create mode 100644 internal/maxmind/instance.go rename internal/{acl => maxmind}/maxmind.go (69%) create mode 100644 internal/maxmind/maxmind_test.go rename internal/{acl => maxmind}/types/city_info.go (92%) create mode 100644 internal/maxmind/types/config.go rename internal/{acl => maxmind}/types/ip_info.go (82%) diff --git a/internal/acl/config.go b/internal/acl/config.go index 1bbe147..f058013 100644 --- a/internal/acl/config.go +++ b/internal/acl/config.go @@ -2,17 +2,13 @@ package acl import ( "net" - "sync" "time" - "github.com/oschwald/maxminddb-golang" "github.com/puzpuzpuz/xsync/v3" - "github.com/rs/zerolog" - acl "github.com/yusing/go-proxy/internal/acl/types" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/accesslog" + "github.com/yusing/go-proxy/internal/maxmind" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) @@ -20,43 +16,23 @@ import ( type Config struct { Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow AllowLocal *bool `json:"allow_local"` // default: true - Allow []string `json:"allow"` - Deny []string `json:"deny"` + Allow Matchers `json:"allow"` + Deny Matchers `json:"deny"` Log *accesslog.ACLLoggerConfig `json:"log"` - MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"` - config } -type ( - MaxMindDatabaseType string - MaxMindConfig struct { - AccountID string `json:"account_id" validate:"required"` - LicenseKey string `json:"license_key" validate:"required"` - Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"` - - logger zerolog.Logger - lastUpdate time.Time - db struct { - *maxminddb.Reader - sync.RWMutex - } - } -) - type config struct { defaultAllow bool allowLocal bool - allow []matcher - deny []matcher ipCache *xsync.MapOf[string, *checkCache] logAllowed bool logger *accesslog.AccessLogger } type checkCache struct { - *acl.IPInfo + *maxmind.IPInfo allow bool created time.Time } @@ -74,11 +50,6 @@ const ( ACLDeny = "deny" ) -const ( - MaxMindGeoLite MaxMindDatabaseType = "geolite" - MaxMindGeoIP2 MaxMindDatabaseType = "geoip2" -) - func (c *Config) Validate() gperr.Error { switch c.Default { case "", ACLAllow: @@ -95,55 +66,19 @@ func (c *Config) Validate() gperr.Error { c.allowLocal = true } - if c.MaxMind != nil { - c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger() - } - if c.Log != nil { c.logAllowed = c.Log.LogAllowed } - errs := gperr.NewBuilder("syntax error") - c.allow = make([]matcher, 0, len(c.Allow)) - c.deny = make([]matcher, 0, len(c.Deny)) - - for _, s := range c.Allow { - m, err := c.parseMatcher(s) - if err != nil { - errs.Add(err.Subject(s)) - continue - } - c.allow = append(c.allow, m) - } - for _, s := range c.Deny { - m, err := c.parseMatcher(s) - if err != nil { - errs.Add(err.Subject(s)) - continue - } - c.deny = append(c.deny, m) - } - - if errs.HasError() { - c.allow = nil - c.deny = nil - return errMatcherFormat.With(errs.Error()) - } - c.ipCache = xsync.NewMapOf[string, *checkCache]() return nil } func (c *Config) Valid() bool { - return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal) + return c != nil && (len(c.Allow) > 0 || len(c.Deny) > 0 || c.allowLocal) } func (c *Config) Start(parent *task.Task) gperr.Error { - if c.MaxMind != nil { - if err := c.MaxMind.LoadMaxMindDB(parent); err != nil { - return err - } - } if c.Log != nil { logger, err := accesslog.NewAccessLogger(parent, c.Log) if err != nil { @@ -154,9 +89,9 @@ func (c *Config) Start(parent *task.Task) gperr.Error { return nil } -func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) { +func (c *Config) cacheRecord(info *maxmind.IPInfo, allow bool) { if common.ForceResolveCountry && info.City == nil { - c.MaxMind.lookupCity(info) + maxmind.LookupCity(info) } c.ipCache.Store(info.Str, &checkCache{ IPInfo: info, @@ -165,7 +100,7 @@ func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) { }) } -func (c *config) log(info *acl.IPInfo, allowed bool) { +func (c *config) log(info *maxmind.IPInfo, allowed bool) { if c.logger == nil { return } @@ -186,7 +121,7 @@ func (c *Config) IPAllowed(ip net.IP) bool { } if c.allowLocal && ip.IsPrivate() { - c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true) + c.log(&maxmind.IPInfo{IP: ip, Str: ip.String()}, true) return true } @@ -197,15 +132,15 @@ func (c *Config) IPAllowed(ip net.IP) bool { return record.allow } - ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr} - for _, m := range c.allow { + ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr} + for _, m := range c.Allow { if m(ipAndStr) { c.log(ipAndStr, true) c.cacheRecord(ipAndStr, true) return true } } - for _, m := range c.deny { + for _, m := range c.Deny { if m(ipAndStr) { c.log(ipAndStr, false) c.cacheRecord(ipAndStr, false) diff --git a/internal/acl/matcher.go b/internal/acl/matcher.go index ae0f385..27abe0d 100644 --- a/internal/acl/matcher.go +++ b/internal/acl/matcher.go @@ -4,11 +4,12 @@ import ( "net" "strings" - acl "github.com/yusing/go-proxy/internal/acl/types" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/maxmind" ) -type matcher func(*acl.IPInfo) bool +type Matcher func(*maxmind.IPInfo) bool +type Matchers []Matcher const ( MatcherTypeIP = "ip" @@ -32,7 +33,7 @@ var ( errMaxMindNotConfigured = gperr.New("MaxMind not configured") ) -func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) { +func ParseMatcher(s string) (Matcher, gperr.Error) { parts := strings.Split(s, ":") if len(parts) != 2 { return nil, errSyntax @@ -52,35 +53,44 @@ func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) { } return matchCIDR(net), nil case MatcherTypeTimeZone: - if cfg.MaxMind == nil { + if !maxmind.HasInstance() { return nil, errMaxMindNotConfigured } - return cfg.MaxMind.matchTimeZone(parts[1]), nil + return matchTimeZone(parts[1]), nil case MatcherTypeCountry: - if cfg.MaxMind == nil { + if !maxmind.HasInstance() { return nil, errMaxMindNotConfigured } - return cfg.MaxMind.matchISOCode(parts[1]), nil + return matchISOCode(parts[1]), nil default: return nil, errSyntax } } -func matchIP(ip net.IP) matcher { - return func(ip2 *acl.IPInfo) bool { +func (matchers Matchers) Match(ip *maxmind.IPInfo) bool { + for _, m := range matchers { + if m(ip) { + return true + } + } + return false +} + +func matchIP(ip net.IP) Matcher { + return func(ip2 *maxmind.IPInfo) bool { return ip.Equal(ip2.IP) } } -func matchCIDR(n *net.IPNet) matcher { - return func(ip *acl.IPInfo) bool { +func matchCIDR(n *net.IPNet) Matcher { + return func(ip *maxmind.IPInfo) bool { return n.Contains(ip.IP) } } -func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher { - return func(ip *acl.IPInfo) bool { - city, ok := cfg.lookupCity(ip) +func matchTimeZone(tz string) Matcher { + return func(ip *maxmind.IPInfo) bool { + city, ok := maxmind.LookupCity(ip) if !ok { return false } @@ -88,9 +98,9 @@ func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher { } } -func (cfg *MaxMindConfig) matchISOCode(iso string) matcher { - return func(ip *acl.IPInfo) bool { - city, ok := cfg.lookupCity(ip) +func matchISOCode(iso string) Matcher { + return func(ip *maxmind.IPInfo) bool { + city, ok := maxmind.LookupCity(ip) if !ok { return false } diff --git a/internal/acl/maxmind_test.go b/internal/acl/maxmind_test.go deleted file mode 100644 index 6c0c0c9..0000000 --- a/internal/acl/maxmind_test.go +++ /dev/null @@ -1,223 +0,0 @@ -package acl - -import ( - "archive/tar" - "compress/gzip" - "io" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/oschwald/maxminddb-golang" - "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/task" -) - -func Test_dbPath(t *testing.T) { - tmpDataDir := "/tmp/testdata" - oldDataDir := dataDir - dataDir = tmpDataDir - defer func() { dataDir = oldDataDir }() - - tests := []struct { - name string - dbType MaxMindDatabaseType - want string - }{ - {"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")}, - {"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := dbPath(tt.dbType); got != tt.want { - t.Errorf("dbPath() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_dbURL(t *testing.T) { - tests := []struct { - name string - dbType MaxMindDatabaseType - want string - }{ - {"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"}, - {"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := dbURL(tt.dbType); got != tt.want { - t.Errorf("dbURL() = %v, want %v", got, tt.want) - } - }) - } -} - -// --- Helper for MaxMindConfig --- -type testLogger struct{ zerolog.Logger } - -func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} } -func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} } -func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} } - -func Test_MaxMindConfig_newReq(t *testing.T) { - cfg := &MaxMindConfig{ - AccountID: "testid", - LicenseKey: "testkey", - Database: MaxMindGeoLite, - logger: zerolog.Nop(), - } - - // Patch httpClient to use httptest - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" { - t.Errorf("basic auth not set correctly") - } - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - oldURL := dbURL - dbURL = func(MaxMindDatabaseType) string { return server.URL } - defer func() { dbURL = oldURL }() - - resp, err := cfg.newReq(http.MethodGet) - if err != nil { - t.Fatalf("newReq() error = %v", err) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("unexpected status: %v", resp.StatusCode) - } -} - -func Test_MaxMindConfig_checkUpdate(t *testing.T) { - cfg := &MaxMindConfig{ - AccountID: "id", - LicenseKey: "key", - Database: MaxMindGeoLite, - logger: zerolog.Nop(), - } - lastMod := time.Now().UTC().Format(http.TimeFormat) - buildTime := time.Now().Add(-time.Hour) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Last-Modified", lastMod) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - oldURL := dbURL - dbURL = func(MaxMindDatabaseType) string { return server.URL } - defer func() { dbURL = oldURL }() - - latest, err := cfg.checkLastest() - if err != nil { - t.Fatalf("checkUpdate() error = %v", err) - } - if latest.Equal(buildTime) { - t.Errorf("expected update needed") - } -} - -type fakeReadCloser struct { - firstRead bool - closed bool -} - -func (c *fakeReadCloser) Read(p []byte) (int, error) { - if !c.firstRead { - c.firstRead = true - return strings.NewReader("FAKEMMDB").Read(p) - } - return 0, io.EOF -} - -func (c *fakeReadCloser) Close() error { - c.closed = true - return nil -} - -func Test_MaxMindConfig_download(t *testing.T) { - cfg := &MaxMindConfig{ - AccountID: "id", - LicenseKey: "key", - Database: MaxMindGeoLite, - logger: zerolog.Nop(), - } - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gz := gzip.NewWriter(w) - t := tar.NewWriter(gz) - t.WriteHeader(&tar.Header{ - Name: dbFilename(MaxMindGeoLite), - }) - t.Write([]byte("1234")) - t.Close() - gz.Close() - })) - defer server.Close() - - oldURL := dbURL - dbURL = func(MaxMindDatabaseType) string { return server.URL } - defer func() { dbURL = oldURL }() - - tmpDir := t.TempDir() - oldDataDir := dataDir - dataDir = tmpDir - defer func() { dataDir = oldDataDir }() - - // Patch maxminddb.Open to always succeed - origOpen := maxmindDBOpen - maxmindDBOpen = func(path string) (*maxminddb.Reader, error) { - return &maxminddb.Reader{}, nil - } - defer func() { maxmindDBOpen = origOpen }() - - req, err := http.NewRequest(http.MethodGet, server.URL, nil) - if err != nil { - t.Fatalf("newReq() error = %v", err) - } - - rw := httptest.NewRecorder() - oldNewReq := newReq - newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) { - server.Config.Handler.ServeHTTP(rw, req) - return rw.Result(), nil - } - defer func() { newReq = oldNewReq }() - - err = cfg.download() - if err != nil { - t.Fatalf("download() error = %v", err) - } - if cfg.db.Reader == nil { - t.Error("expected db instance") - } -} - -func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) { - // This test should cover both the path where DB exists and where it does not - // For brevity, only the non-existing path is tested here - cfg := &MaxMindConfig{ - AccountID: "id", - LicenseKey: "key", - Database: MaxMindGeoLite, - logger: zerolog.Nop(), - } - oldOpen := maxmindDBOpen - maxmindDBOpen = func(path string) (*maxminddb.Reader, error) { - return &maxminddb.Reader{}, nil - } - defer func() { maxmindDBOpen = oldOpen }() - - oldDBPath := dbPath - dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") } - defer func() { dbPath = oldDBPath }() - - task := task.RootTask("test") - defer task.Finish(nil) - err := cfg.LoadMaxMindDB(task) - if err != nil { - t.Errorf("loadMaxMindDB() error = %v", err) - } -} diff --git a/internal/config/config.go b/internal/config/config.go index 99605e4..e352734 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,6 +17,7 @@ import ( "github.com/yusing/go-proxy/internal/entrypoint" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/maxmind" "github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/proxmox" @@ -230,6 +231,7 @@ func (cfg *Config) load() gperr.Error { errs := gperr.NewBuilder(errMsg) errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) + errs.Add(cfg.initMaxMind(model.Providers.MaxMind)) cfg.initNotification(model.Providers.Notification) errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.initProxmox(model.Providers.Proxmox)) @@ -262,6 +264,13 @@ func (cfg *Config) load() gperr.Error { return nil } +func (cfg *Config) initMaxMind(maxmindCfg *maxmind.Config) gperr.Error { + if maxmindCfg != nil { + return maxmind.SetInstance(cfg.task, maxmindCfg) + } + return nil +} + func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) { if len(notifCfg) == 0 { return diff --git a/internal/config/types/config.go b/internal/config/types/config.go index a6396fe..7db6036 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -11,6 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging/accesslog" + maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/proxmox" "github.com/yusing/go-proxy/internal/utils" @@ -32,6 +33,7 @@ type ( Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` Proxmox []proxmox.Config `json:"proxmox" yaml:"proxmox,omitempty"` + MaxMind *maxmind.Config `json:"maxmind" yaml:"maxmind,omitempty"` } Entrypoint struct { Middlewares []map[string]any `json:"middlewares"` diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index 978f891..93ea54b 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -9,9 +9,9 @@ import ( "time" "github.com/rs/zerolog" - acl "github.com/yusing/go-proxy/internal/acl/types" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" + maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/synk" @@ -61,7 +61,7 @@ type ( } ACLFormatter interface { // AppendACLLog appends a log line to line with or without a trailing newline - AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte + AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte } ) @@ -179,7 +179,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) { l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()}) } -func (l *AccessLogger) LogACL(info *acl.IPInfo, blocked bool) { +func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { line := l.lineBufPool.Get() defer l.lineBufPool.Put(line) line = l.ACLFormatter.AppendACLLog(line, info, blocked) diff --git a/internal/logging/accesslog/formatter.go b/internal/logging/accesslog/formatter.go index 6b553ff..196a7ad 100644 --- a/internal/logging/accesslog/formatter.go +++ b/internal/logging/accesslog/formatter.go @@ -8,7 +8,7 @@ import ( "strconv" "github.com/rs/zerolog" - acl "github.com/yusing/go-proxy/internal/acl/types" + maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/utils" ) @@ -158,7 +158,7 @@ func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *ht return writer.Bytes() } -func (f ACLLogFormatter) AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte { +func (f ACLLogFormatter) AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte { writer := bytes.NewBuffer(line) logger := zerolog.New(writer) event := logger.Info(). diff --git a/internal/acl/city_cache.go b/internal/maxmind/city_cache.go similarity index 66% rename from internal/acl/city_cache.go rename to internal/maxmind/city_cache.go index f18b0fc..3d7d781 100644 --- a/internal/acl/city_cache.go +++ b/internal/maxmind/city_cache.go @@ -1,13 +1,12 @@ -package acl +package maxmind import ( "github.com/puzpuzpuz/xsync/v3" - acl "github.com/yusing/go-proxy/internal/acl/types" ) -var cityCache = xsync.NewMapOf[string, *acl.City]() +var cityCache = xsync.NewMapOf[string, *City]() -func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) { +func (cfg *MaxMind) lookupCity(ip *IPInfo) (*City, bool) { if ip.City != nil { return ip.City, true } @@ -25,7 +24,7 @@ func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) { cfg.db.RLock() defer cfg.db.RUnlock() - city = new(acl.City) + city = new(City) err := cfg.db.Lookup(ip.IP, city) if err != nil { return nil, false diff --git a/internal/maxmind/instance.go b/internal/maxmind/instance.go new file mode 100644 index 0000000..b850626 --- /dev/null +++ b/internal/maxmind/instance.go @@ -0,0 +1,31 @@ +package maxmind + +import ( + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/task" +) + +var instance *MaxMind + +func SetInstance(parent task.Parent, cfg *Config) gperr.Error { + newInstance := &MaxMind{Config: cfg} + if err := newInstance.LoadMaxMindDB(parent); err != nil { + return err + } + if instance != nil { + instance.task.Finish("updated") + } + instance = newInstance + return nil +} + +func HasInstance() bool { + return instance != nil +} + +func LookupCity(ip *IPInfo) (*City, bool) { + if instance == nil { + return nil, false + } + return instance.lookupCity(ip) +} diff --git a/internal/acl/maxmind.go b/internal/maxmind/maxmind.go similarity index 69% rename from internal/acl/maxmind.go rename to internal/maxmind/maxmind.go index eed630e..dad6beb 100644 --- a/internal/acl/maxmind.go +++ b/internal/maxmind/maxmind.go @@ -1,4 +1,4 @@ -package acl +package maxmind import ( "archive/tar" @@ -9,14 +9,32 @@ import ( "net/http" "os" "path/filepath" + "sync" "time" "github.com/oschwald/maxminddb-golang" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" + maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/task" ) +type MaxMind struct { + *Config + lastUpdate time.Time + task *task.Task + db struct { + *maxminddb.Reader + sync.RWMutex + } +} + +type ( + Config = maxmind.Config + IPInfo = maxmind.IPInfo + City = maxmind.City +) + var ( updateInterval = 24 * time.Hour httpClient = &http.Client{ @@ -26,33 +44,34 @@ var ( ErrDownloadFailure = gperr.New("download failure") ) -func dbPathImpl(dbType MaxMindDatabaseType) string { - if dbType == MaxMindGeoLite { +func (cfg *MaxMind) dbPath() string { + if cfg.Database == maxmind.MaxMindGeoLite { return filepath.Join(dataDir, "GeoLite2-City.mmdb") } return filepath.Join(dataDir, "GeoIP2-City.mmdb") } -func dbURLimpl(dbType MaxMindDatabaseType) string { - if dbType == MaxMindGeoLite { +func (cfg *MaxMind) dbURL() string { + if cfg.Database == maxmind.MaxMindGeoLite { return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz" } return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz" } -func dbFilename(dbType MaxMindDatabaseType) string { - if dbType == MaxMindGeoLite { +func (cfg *MaxMind) dbFilename() string { + if cfg.Database == maxmind.MaxMindGeoLite { return "GeoLite2-City.mmdb" } return "GeoIP2-City.mmdb" } -func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error { +func (cfg *MaxMind) LoadMaxMindDB(parent task.Parent) gperr.Error { if cfg.Database == "" { return nil } - path := dbPath(cfg.Database) + cfg.task = parent.Subtask("maxmind_db", true) + path := dbPath(cfg) reader, err := maxmindDBOpen(path) valid := true if err != nil { @@ -69,32 +88,32 @@ func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error { } if !valid { - cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...") + cfg.Logger().Info().Msg("MaxMind DB not found/invalid, downloading...") if err = cfg.download(); err != nil { return ErrDownloadFailure.With(err) } } else { - cfg.logger.Info().Msg("MaxMind DB loaded") + cfg.Logger().Info().Msg("MaxMind DB loaded") cfg.db.Reader = reader - go cfg.scheduleUpdate(parent) + go cfg.scheduleUpdate(cfg.task) } return nil } -func (cfg *MaxMindConfig) loadLastUpdate() { - f, err := os.Stat(dbPath(cfg.Database)) +func (cfg *MaxMind) loadLastUpdate() { + f, err := os.Stat(cfg.dbPath()) if err != nil { return } cfg.lastUpdate = f.ModTime() } -func (cfg *MaxMindConfig) setLastUpdate(t time.Time) { +func (cfg *MaxMind) setLastUpdate(t time.Time) { cfg.lastUpdate = t - _ = os.Chtimes(dbPath(cfg.Database), t, t) + _ = os.Chtimes(cfg.dbPath(), t, t) } -func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) { +func (cfg *MaxMind) scheduleUpdate(parent task.Parent) { task := parent.Subtask("schedule_update", true) ticker := time.NewTicker(updateInterval) @@ -119,45 +138,45 @@ func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) { } } -func (cfg *MaxMindConfig) update() { +func (cfg *MaxMind) update() { // check for update - cfg.logger.Info().Msg("checking for MaxMind DB update...") + cfg.Logger().Info().Msg("checking for MaxMind DB update...") remoteLastModified, err := cfg.checkLastest() if err != nil { - cfg.logger.Err(err).Msg("failed to check MaxMind DB update") + cfg.Logger().Err(err).Msg("failed to check MaxMind DB update") return } if remoteLastModified.Equal(cfg.lastUpdate) { - cfg.logger.Info().Msg("MaxMind DB is up to date") + cfg.Logger().Info().Msg("MaxMind DB is up to date") return } - cfg.logger.Info(). + cfg.Logger().Info(). Time("latest", remoteLastModified.Local()). Time("current", cfg.lastUpdate). Msg("MaxMind DB update available") if err = cfg.download(); err != nil { - cfg.logger.Err(err).Msg("failed to update MaxMind DB") + cfg.Logger().Err(err).Msg("failed to update MaxMind DB") return } - cfg.logger.Info().Msg("MaxMind DB updated") + cfg.Logger().Info().Msg("MaxMind DB updated") } -func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) { - req, err := http.NewRequest(method, dbURL(cfg.Database), nil) +func (cfg *MaxMind) doReq(method string) (*http.Response, error) { + req, err := http.NewRequest(method, cfg.dbURL(), nil) if err != nil { return nil, err } req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey) - resp, err := httpClient.Do(req) + resp, err := doReq(req) if err != nil { return nil, err } return resp, nil } -func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) { - resp, err := newReq(cfg, http.MethodHead) +func (cfg *MaxMind) checkLastest() (lastModifiedT *time.Time, err error) { + resp, err := cfg.doReq(http.MethodHead) if err != nil { return nil, err } @@ -169,21 +188,21 @@ func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) { lastModified := resp.Header.Get("Last-Modified") if lastModified == "" { - cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped") + cfg.Logger().Warn().Msg("MaxMind responded no last modified time, update skipped") return nil, nil } lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified) if err != nil { - cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped") + cfg.Logger().Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped") return nil, err } return &lastModifiedTime, nil } -func (cfg *MaxMindConfig) download() error { - resp, err := newReq(cfg, http.MethodGet) +func (cfg *MaxMind) download() error { + resp, err := cfg.doReq(http.MethodGet) if err != nil { return err } @@ -193,7 +212,7 @@ func (cfg *MaxMindConfig) download() error { return fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode) } - dbFile := dbPath(cfg.Database) + dbFile := dbPath(cfg) tmpGZPath := dbFile + "-tmp.tar.gz" tmpDBPath := dbFile + "-tmp" @@ -208,7 +227,7 @@ func (cfg *MaxMindConfig) download() error { _ = os.Remove(tmpGZPath) }() - cfg.logger.Info().Msg("MaxMind DB downloading...") + cfg.Logger().Info().Msg("MaxMind DB downloading...") _, err = io.Copy(tmpGZFile, resp.Body) if err != nil { @@ -220,7 +239,7 @@ func (cfg *MaxMindConfig) download() error { } // extract .tar.gz and to database - err = extractFileFromTarGz(tmpGZFile, dbFilename(cfg.Database), tmpDBPath) + err = extractFileFromTarGz(tmpGZFile, cfg.dbFilename(), tmpDBPath) if err != nil { return gperr.New("failed to extract database from archive").With(err) @@ -255,7 +274,7 @@ func (cfg *MaxMindConfig) download() error { cfg.setLastUpdate(lastModifiedTime) } - cfg.logger.Info().Msg("MaxMind DB downloaded") + cfg.Logger().Info().Msg("MaxMind DB downloaded") return nil } @@ -296,8 +315,7 @@ func extractFileFromTarGz(tarGzFile *os.File, targetFilename, destPath string) e var ( dataDir = common.DataDir - dbURL = dbURLimpl - dbPath = dbPathImpl + dbPath = (*MaxMind).dbPath + doReq = httpClient.Do maxmindDBOpen = maxminddb.Open - newReq = (*MaxMindConfig).newReq ) diff --git a/internal/maxmind/maxmind_test.go b/internal/maxmind/maxmind_test.go new file mode 100644 index 0000000..0478530 --- /dev/null +++ b/internal/maxmind/maxmind_test.go @@ -0,0 +1,131 @@ +package maxmind + +import ( + "archive/tar" + "compress/gzip" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/oschwald/maxminddb-golang" + "github.com/rs/zerolog" + maxmind "github.com/yusing/go-proxy/internal/maxmind/types" + "github.com/yusing/go-proxy/internal/task" +) + +// --- Helper for MaxMindConfig --- +type testLogger struct{ zerolog.Logger } + +func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} } +func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} } +func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} } + +func testCfg() *MaxMind { + return &MaxMind{ + Config: &Config{ + AccountID: "testid", + LicenseKey: "testkey", + Database: maxmind.MaxMindGeoLite, + }, + } +} + +var testLastMod = time.Now().UTC() + +func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) { + if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat)) + gz := gzip.NewWriter(w) + t := tar.NewWriter(gz) + t.WriteHeader(&tar.Header{ + Name: cfg.dbFilename(), + }) + t.Write([]byte("1234")) + t.Close() + gz.Close() + w.WriteHeader(http.StatusOK) +} + +func mockDoReq(cfg *MaxMind, t *testing.T) { + rw := httptest.NewRecorder() + oldDoReq := doReq + doReq = func(req *http.Request) (*http.Response, error) { + testDoReq(cfg, rw, req) + return rw.Result(), nil + } + t.Cleanup(func() { doReq = oldDoReq }) +} + +func mockDataDir(t *testing.T) { + oldDataDir := dataDir + dataDir = t.TempDir() + t.Cleanup(func() { dataDir = oldDataDir }) +} + +func mockMaxMindDBOpen(t *testing.T) { + oldMaxMindDBOpen := maxmindDBOpen + maxmindDBOpen = func(path string) (*maxminddb.Reader, error) { + return &maxminddb.Reader{}, nil + } + t.Cleanup(func() { maxmindDBOpen = oldMaxMindDBOpen }) +} + +func Test_MaxMindConfig_doReq(t *testing.T) { + cfg := testCfg() + mockDoReq(cfg, t) + resp, err := cfg.doReq(http.MethodGet) + if err != nil { + t.Fatalf("newReq() error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("unexpected status: %v", resp.StatusCode) + } +} + +func Test_MaxMindConfig_checkLatest(t *testing.T) { + cfg := testCfg() + mockDoReq(cfg, t) + + latest, err := cfg.checkLastest() + if err != nil { + t.Fatalf("checkLatest() error = %v", err) + } + if latest.Equal(testLastMod) { + t.Errorf("expected latest equal to testLastMod") + } +} + +func Test_MaxMindConfig_download(t *testing.T) { + cfg := testCfg() + mockDataDir(t) + mockMaxMindDBOpen(t) + mockDoReq(cfg, t) + + err := cfg.download() + if err != nil { + t.Fatalf("download() error = %v", err) + } + if cfg.db.Reader == nil { + t.Error("expected db instance") + } +} + +func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) { + cfg := testCfg() + mockDataDir(t) + mockMaxMindDBOpen(t) + + task := task.RootTask("test") + defer task.Finish(nil) + err := cfg.LoadMaxMindDB(task) + if err != nil { + t.Errorf("loadMaxMindDB() error = %v", err) + } + if cfg.db.Reader == nil { + t.Error("expected db instance") + } +} diff --git a/internal/acl/types/city_info.go b/internal/maxmind/types/city_info.go similarity index 92% rename from internal/acl/types/city_info.go rename to internal/maxmind/types/city_info.go index 05b4315..98f1694 100644 --- a/internal/acl/types/city_info.go +++ b/internal/maxmind/types/city_info.go @@ -1,4 +1,4 @@ -package acl +package maxmind type City struct { Location struct { diff --git a/internal/maxmind/types/config.go b/internal/maxmind/types/config.go new file mode 100644 index 0000000..298f3a5 --- /dev/null +++ b/internal/maxmind/types/config.go @@ -0,0 +1,33 @@ +package maxmind + +import ( + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" +) + +type ( + DatabaseType string + Config struct { + AccountID string `json:"account_id" validate:"required"` + LicenseKey string `json:"license_key" validate:"required"` + Database DatabaseType `json:"database" validate:"omitempty,oneof=geolite geoip2"` + } +) + +const ( + MaxMindGeoLite DatabaseType = "geolite" + MaxMindGeoIP2 DatabaseType = "geoip2" +) + +func (cfg *Config) Validate() gperr.Error { + if cfg.Database == "" { + cfg.Database = MaxMindGeoLite + } + return nil +} + +func (cfg *Config) Logger() *zerolog.Logger { + l := logging.With().Str("database", string(cfg.Database)).Logger() + return &l +} diff --git a/internal/acl/types/ip_info.go b/internal/maxmind/types/ip_info.go similarity index 82% rename from internal/acl/types/ip_info.go rename to internal/maxmind/types/ip_info.go index 13dec8b..c7329af 100644 --- a/internal/acl/types/ip_info.go +++ b/internal/maxmind/types/ip_info.go @@ -1,4 +1,4 @@ -package acl +package maxmind import "net"