breaking: move maxmind config to config.providers

- moved maxmind to separate module
- code refactored
- simplified test
This commit is contained in:
yusing 2025-05-03 20:58:09 +08:00
parent ac1470d81d
commit ccb4639f43
14 changed files with 314 additions and 369 deletions

View file

@ -2,17 +2,13 @@ package acl
import (
"net"
"sync"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
@ -20,43 +16,23 @@ import (
type Config struct {
Default string `json:"default" validate:"omitempty,oneof=allow deny"` // default: allow
AllowLocal *bool `json:"allow_local"` // default: true
Allow []string `json:"allow"`
Deny []string `json:"deny"`
Allow Matchers `json:"allow"`
Deny Matchers `json:"deny"`
Log *accesslog.ACLLoggerConfig `json:"log"`
MaxMind *MaxMindConfig `json:"maxmind" validate:"omitempty"`
config
}
type (
MaxMindDatabaseType string
MaxMindConfig struct {
AccountID string `json:"account_id" validate:"required"`
LicenseKey string `json:"license_key" validate:"required"`
Database MaxMindDatabaseType `json:"database" validate:"required,oneof=geolite geoip2"`
logger zerolog.Logger
lastUpdate time.Time
db struct {
*maxminddb.Reader
sync.RWMutex
}
}
)
type config struct {
defaultAllow bool
allowLocal bool
allow []matcher
deny []matcher
ipCache *xsync.MapOf[string, *checkCache]
logAllowed bool
logger *accesslog.AccessLogger
}
type checkCache struct {
*acl.IPInfo
*maxmind.IPInfo
allow bool
created time.Time
}
@ -74,11 +50,6 @@ const (
ACLDeny = "deny"
)
const (
MaxMindGeoLite MaxMindDatabaseType = "geolite"
MaxMindGeoIP2 MaxMindDatabaseType = "geoip2"
)
func (c *Config) Validate() gperr.Error {
switch c.Default {
case "", ACLAllow:
@ -95,55 +66,19 @@ func (c *Config) Validate() gperr.Error {
c.allowLocal = true
}
if c.MaxMind != nil {
c.MaxMind.logger = logging.With().Str("type", string(c.MaxMind.Database)).Logger()
}
if c.Log != nil {
c.logAllowed = c.Log.LogAllowed
}
errs := gperr.NewBuilder("syntax error")
c.allow = make([]matcher, 0, len(c.Allow))
c.deny = make([]matcher, 0, len(c.Deny))
for _, s := range c.Allow {
m, err := c.parseMatcher(s)
if err != nil {
errs.Add(err.Subject(s))
continue
}
c.allow = append(c.allow, m)
}
for _, s := range c.Deny {
m, err := c.parseMatcher(s)
if err != nil {
errs.Add(err.Subject(s))
continue
}
c.deny = append(c.deny, m)
}
if errs.HasError() {
c.allow = nil
c.deny = nil
return errMatcherFormat.With(errs.Error())
}
c.ipCache = xsync.NewMapOf[string, *checkCache]()
return nil
}
func (c *Config) Valid() bool {
return c != nil && (len(c.allow) > 0 || len(c.deny) > 0 || c.allowLocal)
return c != nil && (len(c.Allow) > 0 || len(c.Deny) > 0 || c.allowLocal)
}
func (c *Config) Start(parent *task.Task) gperr.Error {
if c.MaxMind != nil {
if err := c.MaxMind.LoadMaxMindDB(parent); err != nil {
return err
}
}
if c.Log != nil {
logger, err := accesslog.NewAccessLogger(parent, c.Log)
if err != nil {
@ -154,9 +89,9 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
return nil
}
func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) {
func (c *Config) cacheRecord(info *maxmind.IPInfo, allow bool) {
if common.ForceResolveCountry && info.City == nil {
c.MaxMind.lookupCity(info)
maxmind.LookupCity(info)
}
c.ipCache.Store(info.Str, &checkCache{
IPInfo: info,
@ -165,7 +100,7 @@ func (c *Config) cacheRecord(info *acl.IPInfo, allow bool) {
})
}
func (c *config) log(info *acl.IPInfo, allowed bool) {
func (c *config) log(info *maxmind.IPInfo, allowed bool) {
if c.logger == nil {
return
}
@ -186,7 +121,7 @@ func (c *Config) IPAllowed(ip net.IP) bool {
}
if c.allowLocal && ip.IsPrivate() {
c.log(&acl.IPInfo{IP: ip, Str: ip.String()}, true)
c.log(&maxmind.IPInfo{IP: ip, Str: ip.String()}, true)
return true
}
@ -197,15 +132,15 @@ func (c *Config) IPAllowed(ip net.IP) bool {
return record.allow
}
ipAndStr := &acl.IPInfo{IP: ip, Str: ipStr}
for _, m := range c.allow {
ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr}
for _, m := range c.Allow {
if m(ipAndStr) {
c.log(ipAndStr, true)
c.cacheRecord(ipAndStr, true)
return true
}
}
for _, m := range c.deny {
for _, m := range c.Deny {
if m(ipAndStr) {
c.log(ipAndStr, false)
c.cacheRecord(ipAndStr, false)

View file

@ -4,11 +4,12 @@ import (
"net"
"strings"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/maxmind"
)
type matcher func(*acl.IPInfo) bool
type Matcher func(*maxmind.IPInfo) bool
type Matchers []Matcher
const (
MatcherTypeIP = "ip"
@ -32,7 +33,7 @@ var (
errMaxMindNotConfigured = gperr.New("MaxMind not configured")
)
func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
func ParseMatcher(s string) (Matcher, gperr.Error) {
parts := strings.Split(s, ":")
if len(parts) != 2 {
return nil, errSyntax
@ -52,35 +53,44 @@ func (cfg *Config) parseMatcher(s string) (matcher, gperr.Error) {
}
return matchCIDR(net), nil
case MatcherTypeTimeZone:
if cfg.MaxMind == nil {
if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured
}
return cfg.MaxMind.matchTimeZone(parts[1]), nil
return matchTimeZone(parts[1]), nil
case MatcherTypeCountry:
if cfg.MaxMind == nil {
if !maxmind.HasInstance() {
return nil, errMaxMindNotConfigured
}
return cfg.MaxMind.matchISOCode(parts[1]), nil
return matchISOCode(parts[1]), nil
default:
return nil, errSyntax
}
}
func matchIP(ip net.IP) matcher {
return func(ip2 *acl.IPInfo) bool {
func (matchers Matchers) Match(ip *maxmind.IPInfo) bool {
for _, m := range matchers {
if m(ip) {
return true
}
}
return false
}
func matchIP(ip net.IP) Matcher {
return func(ip2 *maxmind.IPInfo) bool {
return ip.Equal(ip2.IP)
}
}
func matchCIDR(n *net.IPNet) matcher {
return func(ip *acl.IPInfo) bool {
func matchCIDR(n *net.IPNet) Matcher {
return func(ip *maxmind.IPInfo) bool {
return n.Contains(ip.IP)
}
}
func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
return func(ip *acl.IPInfo) bool {
city, ok := cfg.lookupCity(ip)
func matchTimeZone(tz string) Matcher {
return func(ip *maxmind.IPInfo) bool {
city, ok := maxmind.LookupCity(ip)
if !ok {
return false
}
@ -88,9 +98,9 @@ func (cfg *MaxMindConfig) matchTimeZone(tz string) matcher {
}
}
func (cfg *MaxMindConfig) matchISOCode(iso string) matcher {
return func(ip *acl.IPInfo) bool {
city, ok := cfg.lookupCity(ip)
func matchISOCode(iso string) Matcher {
return func(ip *maxmind.IPInfo) bool {
city, ok := maxmind.LookupCity(ip)
if !ok {
return false
}

View file

@ -1,223 +0,0 @@
package acl
import (
"archive/tar"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/task"
)
func Test_dbPath(t *testing.T) {
tmpDataDir := "/tmp/testdata"
oldDataDir := dataDir
dataDir = tmpDataDir
defer func() { dataDir = oldDataDir }()
tests := []struct {
name string
dbType MaxMindDatabaseType
want string
}{
{"GeoLite", MaxMindGeoLite, filepath.Join(tmpDataDir, "GeoLite2-City.mmdb")},
{"GeoIP2", MaxMindGeoIP2, filepath.Join(tmpDataDir, "GeoIP2-City.mmdb")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := dbPath(tt.dbType); got != tt.want {
t.Errorf("dbPath() = %v, want %v", got, tt.want)
}
})
}
}
func Test_dbURL(t *testing.T) {
tests := []struct {
name string
dbType MaxMindDatabaseType
want string
}{
{"GeoLite", MaxMindGeoLite, "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"},
{"GeoIP2", MaxMindGeoIP2, "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := dbURL(tt.dbType); got != tt.want {
t.Errorf("dbURL() = %v, want %v", got, tt.want)
}
})
}
}
// --- Helper for MaxMindConfig ---
type testLogger struct{ zerolog.Logger }
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
func Test_MaxMindConfig_newReq(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "testid",
LicenseKey: "testkey",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
// Patch httpClient to use httptest
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
t.Errorf("basic auth not set correctly")
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
resp, err := cfg.newReq(http.MethodGet)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status: %v", resp.StatusCode)
}
}
func Test_MaxMindConfig_checkUpdate(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
lastMod := time.Now().UTC().Format(http.TimeFormat)
buildTime := time.Now().Add(-time.Hour)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Last-Modified", lastMod)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
latest, err := cfg.checkLastest()
if err != nil {
t.Fatalf("checkUpdate() error = %v", err)
}
if latest.Equal(buildTime) {
t.Errorf("expected update needed")
}
}
type fakeReadCloser struct {
firstRead bool
closed bool
}
func (c *fakeReadCloser) Read(p []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
return strings.NewReader("FAKEMMDB").Read(p)
}
return 0, io.EOF
}
func (c *fakeReadCloser) Close() error {
c.closed = true
return nil
}
func Test_MaxMindConfig_download(t *testing.T) {
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gz := gzip.NewWriter(w)
t := tar.NewWriter(gz)
t.WriteHeader(&tar.Header{
Name: dbFilename(MaxMindGeoLite),
})
t.Write([]byte("1234"))
t.Close()
gz.Close()
}))
defer server.Close()
oldURL := dbURL
dbURL = func(MaxMindDatabaseType) string { return server.URL }
defer func() { dbURL = oldURL }()
tmpDir := t.TempDir()
oldDataDir := dataDir
dataDir = tmpDir
defer func() { dataDir = oldDataDir }()
// Patch maxminddb.Open to always succeed
origOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
defer func() { maxmindDBOpen = origOpen }()
req, err := http.NewRequest(http.MethodGet, server.URL, nil)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
rw := httptest.NewRecorder()
oldNewReq := newReq
newReq = func(cfg *MaxMindConfig, method string) (*http.Response, error) {
server.Config.Handler.ServeHTTP(rw, req)
return rw.Result(), nil
}
defer func() { newReq = oldNewReq }()
err = cfg.download()
if err != nil {
t.Fatalf("download() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
// This test should cover both the path where DB exists and where it does not
// For brevity, only the non-existing path is tested here
cfg := &MaxMindConfig{
AccountID: "id",
LicenseKey: "key",
Database: MaxMindGeoLite,
logger: zerolog.Nop(),
}
oldOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
defer func() { maxmindDBOpen = oldOpen }()
oldDBPath := dbPath
dbPath = func(MaxMindDatabaseType) string { return filepath.Join(t.TempDir(), "maxmind.mmdb") }
defer func() { dbPath = oldDBPath }()
task := task.RootTask("test")
defer task.Finish(nil)
err := cfg.LoadMaxMindDB(task)
if err != nil {
t.Errorf("loadMaxMindDB() error = %v", err)
}
}

View file

@ -17,6 +17,7 @@ import (
"github.com/yusing/go-proxy/internal/entrypoint"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
@ -230,6 +231,7 @@ func (cfg *Config) load() gperr.Error {
errs := gperr.NewBuilder(errMsg)
errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares))
errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog))
errs.Add(cfg.initMaxMind(model.Providers.MaxMind))
cfg.initNotification(model.Providers.Notification)
errs.Add(cfg.initAutoCert(model.AutoCert))
errs.Add(cfg.initProxmox(model.Providers.Proxmox))
@ -262,6 +264,13 @@ func (cfg *Config) load() gperr.Error {
return nil
}
func (cfg *Config) initMaxMind(maxmindCfg *maxmind.Config) gperr.Error {
if maxmindCfg != nil {
return maxmind.SetInstance(cfg.task, maxmindCfg)
}
return nil
}
func (cfg *Config) initNotification(notifCfg []notif.NotificationConfig) {
if len(notifCfg) == 0 {
return

View file

@ -11,6 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging/accesslog"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/utils"
@ -32,6 +33,7 @@ type (
Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"`
Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"`
Proxmox []proxmox.Config `json:"proxmox" yaml:"proxmox,omitempty"`
MaxMind *maxmind.Config `json:"maxmind" yaml:"maxmind,omitempty"`
}
Entrypoint struct {
Middlewares []map[string]any `json:"middlewares"`

View file

@ -9,9 +9,9 @@ import (
"time"
"github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/utils/synk"
@ -61,7 +61,7 @@ type (
}
ACLFormatter interface {
// AppendACLLog appends a log line to line with or without a trailing newline
AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte
AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte
}
)
@ -179,7 +179,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
l.Log(req, &http.Response{StatusCode: http.StatusInternalServerError, Status: err.Error()})
}
func (l *AccessLogger) LogACL(info *acl.IPInfo, blocked bool) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line)
line = l.ACLFormatter.AppendACLLog(line, info, blocked)

View file

@ -8,7 +8,7 @@ import (
"strconv"
"github.com/rs/zerolog"
acl "github.com/yusing/go-proxy/internal/acl/types"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/utils"
)
@ -158,7 +158,7 @@ func (f *JSONFormatter) AppendRequestLog(line []byte, req *http.Request, res *ht
return writer.Bytes()
}
func (f ACLLogFormatter) AppendACLLog(line []byte, info *acl.IPInfo, blocked bool) []byte {
func (f ACLLogFormatter) AppendACLLog(line []byte, info *maxmind.IPInfo, blocked bool) []byte {
writer := bytes.NewBuffer(line)
logger := zerolog.New(writer)
event := logger.Info().

View file

@ -1,13 +1,12 @@
package acl
package maxmind
import (
"github.com/puzpuzpuz/xsync/v3"
acl "github.com/yusing/go-proxy/internal/acl/types"
)
var cityCache = xsync.NewMapOf[string, *acl.City]()
var cityCache = xsync.NewMapOf[string, *City]()
func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) {
func (cfg *MaxMind) lookupCity(ip *IPInfo) (*City, bool) {
if ip.City != nil {
return ip.City, true
}
@ -25,7 +24,7 @@ func (cfg *MaxMindConfig) lookupCity(ip *acl.IPInfo) (*acl.City, bool) {
cfg.db.RLock()
defer cfg.db.RUnlock()
city = new(acl.City)
city = new(City)
err := cfg.db.Lookup(ip.IP, city)
if err != nil {
return nil, false

View file

@ -0,0 +1,31 @@
package maxmind
import (
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/task"
)
var instance *MaxMind
func SetInstance(parent task.Parent, cfg *Config) gperr.Error {
newInstance := &MaxMind{Config: cfg}
if err := newInstance.LoadMaxMindDB(parent); err != nil {
return err
}
if instance != nil {
instance.task.Finish("updated")
}
instance = newInstance
return nil
}
func HasInstance() bool {
return instance != nil
}
func LookupCity(ip *IPInfo) (*City, bool) {
if instance == nil {
return nil, false
}
return instance.lookupCity(ip)
}

View file

@ -1,4 +1,4 @@
package acl
package maxmind
import (
"archive/tar"
@ -9,14 +9,32 @@ import (
"net/http"
"os"
"path/filepath"
"sync"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
)
type MaxMind struct {
*Config
lastUpdate time.Time
task *task.Task
db struct {
*maxminddb.Reader
sync.RWMutex
}
}
type (
Config = maxmind.Config
IPInfo = maxmind.IPInfo
City = maxmind.City
)
var (
updateInterval = 24 * time.Hour
httpClient = &http.Client{
@ -26,33 +44,34 @@ var (
ErrDownloadFailure = gperr.New("download failure")
)
func dbPathImpl(dbType MaxMindDatabaseType) string {
if dbType == MaxMindGeoLite {
func (cfg *MaxMind) dbPath() string {
if cfg.Database == maxmind.MaxMindGeoLite {
return filepath.Join(dataDir, "GeoLite2-City.mmdb")
}
return filepath.Join(dataDir, "GeoIP2-City.mmdb")
}
func dbURLimpl(dbType MaxMindDatabaseType) string {
if dbType == MaxMindGeoLite {
func (cfg *MaxMind) dbURL() string {
if cfg.Database == maxmind.MaxMindGeoLite {
return "https://download.maxmind.com/geoip/databases/GeoLite2-City/download?suffix=tar.gz"
}
return "https://download.maxmind.com/geoip/databases/GeoIP2-City/download?suffix=tar.gz"
}
func dbFilename(dbType MaxMindDatabaseType) string {
if dbType == MaxMindGeoLite {
func (cfg *MaxMind) dbFilename() string {
if cfg.Database == maxmind.MaxMindGeoLite {
return "GeoLite2-City.mmdb"
}
return "GeoIP2-City.mmdb"
}
func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error {
func (cfg *MaxMind) LoadMaxMindDB(parent task.Parent) gperr.Error {
if cfg.Database == "" {
return nil
}
path := dbPath(cfg.Database)
cfg.task = parent.Subtask("maxmind_db", true)
path := dbPath(cfg)
reader, err := maxmindDBOpen(path)
valid := true
if err != nil {
@ -69,32 +88,32 @@ func (cfg *MaxMindConfig) LoadMaxMindDB(parent task.Parent) gperr.Error {
}
if !valid {
cfg.logger.Info().Msg("MaxMind DB not found/invalid, downloading...")
cfg.Logger().Info().Msg("MaxMind DB not found/invalid, downloading...")
if err = cfg.download(); err != nil {
return ErrDownloadFailure.With(err)
}
} else {
cfg.logger.Info().Msg("MaxMind DB loaded")
cfg.Logger().Info().Msg("MaxMind DB loaded")
cfg.db.Reader = reader
go cfg.scheduleUpdate(parent)
go cfg.scheduleUpdate(cfg.task)
}
return nil
}
func (cfg *MaxMindConfig) loadLastUpdate() {
f, err := os.Stat(dbPath(cfg.Database))
func (cfg *MaxMind) loadLastUpdate() {
f, err := os.Stat(cfg.dbPath())
if err != nil {
return
}
cfg.lastUpdate = f.ModTime()
}
func (cfg *MaxMindConfig) setLastUpdate(t time.Time) {
func (cfg *MaxMind) setLastUpdate(t time.Time) {
cfg.lastUpdate = t
_ = os.Chtimes(dbPath(cfg.Database), t, t)
_ = os.Chtimes(cfg.dbPath(), t, t)
}
func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) {
func (cfg *MaxMind) scheduleUpdate(parent task.Parent) {
task := parent.Subtask("schedule_update", true)
ticker := time.NewTicker(updateInterval)
@ -119,45 +138,45 @@ func (cfg *MaxMindConfig) scheduleUpdate(parent task.Parent) {
}
}
func (cfg *MaxMindConfig) update() {
func (cfg *MaxMind) update() {
// check for update
cfg.logger.Info().Msg("checking for MaxMind DB update...")
cfg.Logger().Info().Msg("checking for MaxMind DB update...")
remoteLastModified, err := cfg.checkLastest()
if err != nil {
cfg.logger.Err(err).Msg("failed to check MaxMind DB update")
cfg.Logger().Err(err).Msg("failed to check MaxMind DB update")
return
}
if remoteLastModified.Equal(cfg.lastUpdate) {
cfg.logger.Info().Msg("MaxMind DB is up to date")
cfg.Logger().Info().Msg("MaxMind DB is up to date")
return
}
cfg.logger.Info().
cfg.Logger().Info().
Time("latest", remoteLastModified.Local()).
Time("current", cfg.lastUpdate).
Msg("MaxMind DB update available")
if err = cfg.download(); err != nil {
cfg.logger.Err(err).Msg("failed to update MaxMind DB")
cfg.Logger().Err(err).Msg("failed to update MaxMind DB")
return
}
cfg.logger.Info().Msg("MaxMind DB updated")
cfg.Logger().Info().Msg("MaxMind DB updated")
}
func (cfg *MaxMindConfig) newReq(method string) (*http.Response, error) {
req, err := http.NewRequest(method, dbURL(cfg.Database), nil)
func (cfg *MaxMind) doReq(method string) (*http.Response, error) {
req, err := http.NewRequest(method, cfg.dbURL(), nil)
if err != nil {
return nil, err
}
req.SetBasicAuth(cfg.AccountID, cfg.LicenseKey)
resp, err := httpClient.Do(req)
resp, err := doReq(req)
if err != nil {
return nil, err
}
return resp, nil
}
func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) {
resp, err := newReq(cfg, http.MethodHead)
func (cfg *MaxMind) checkLastest() (lastModifiedT *time.Time, err error) {
resp, err := cfg.doReq(http.MethodHead)
if err != nil {
return nil, err
}
@ -169,21 +188,21 @@ func (cfg *MaxMindConfig) checkLastest() (lastModifiedT *time.Time, err error) {
lastModified := resp.Header.Get("Last-Modified")
if lastModified == "" {
cfg.logger.Warn().Msg("MaxMind responded no last modified time, update skipped")
cfg.Logger().Warn().Msg("MaxMind responded no last modified time, update skipped")
return nil, nil
}
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
if err != nil {
cfg.logger.Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped")
cfg.Logger().Warn().Err(err).Msg("MaxMind responded invalid last modified time, update skipped")
return nil, err
}
return &lastModifiedTime, nil
}
func (cfg *MaxMindConfig) download() error {
resp, err := newReq(cfg, http.MethodGet)
func (cfg *MaxMind) download() error {
resp, err := cfg.doReq(http.MethodGet)
if err != nil {
return err
}
@ -193,7 +212,7 @@ func (cfg *MaxMindConfig) download() error {
return fmt.Errorf("%w: %d", ErrResponseNotOK, resp.StatusCode)
}
dbFile := dbPath(cfg.Database)
dbFile := dbPath(cfg)
tmpGZPath := dbFile + "-tmp.tar.gz"
tmpDBPath := dbFile + "-tmp"
@ -208,7 +227,7 @@ func (cfg *MaxMindConfig) download() error {
_ = os.Remove(tmpGZPath)
}()
cfg.logger.Info().Msg("MaxMind DB downloading...")
cfg.Logger().Info().Msg("MaxMind DB downloading...")
_, err = io.Copy(tmpGZFile, resp.Body)
if err != nil {
@ -220,7 +239,7 @@ func (cfg *MaxMindConfig) download() error {
}
// extract .tar.gz and to database
err = extractFileFromTarGz(tmpGZFile, dbFilename(cfg.Database), tmpDBPath)
err = extractFileFromTarGz(tmpGZFile, cfg.dbFilename(), tmpDBPath)
if err != nil {
return gperr.New("failed to extract database from archive").With(err)
@ -255,7 +274,7 @@ func (cfg *MaxMindConfig) download() error {
cfg.setLastUpdate(lastModifiedTime)
}
cfg.logger.Info().Msg("MaxMind DB downloaded")
cfg.Logger().Info().Msg("MaxMind DB downloaded")
return nil
}
@ -296,8 +315,7 @@ func extractFileFromTarGz(tarGzFile *os.File, targetFilename, destPath string) e
var (
dataDir = common.DataDir
dbURL = dbURLimpl
dbPath = dbPathImpl
dbPath = (*MaxMind).dbPath
doReq = httpClient.Do
maxmindDBOpen = maxminddb.Open
newReq = (*MaxMindConfig).newReq
)

View file

@ -0,0 +1,131 @@
package maxmind
import (
"archive/tar"
"compress/gzip"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/oschwald/maxminddb-golang"
"github.com/rs/zerolog"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
)
// --- Helper for MaxMindConfig ---
type testLogger struct{ zerolog.Logger }
func (testLogger) Info() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Warn() *zerolog.Event { return &zerolog.Event{} }
func (testLogger) Err(_ error) *zerolog.Event { return &zerolog.Event{} }
func testCfg() *MaxMind {
return &MaxMind{
Config: &Config{
AccountID: "testid",
LicenseKey: "testkey",
Database: maxmind.MaxMindGeoLite,
},
}
}
var testLastMod = time.Now().UTC()
func testDoReq(cfg *MaxMind, w http.ResponseWriter, r *http.Request) {
if u, p, ok := r.BasicAuth(); !ok || u != "testid" || p != "testkey" {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Last-Modified", testLastMod.Format(http.TimeFormat))
gz := gzip.NewWriter(w)
t := tar.NewWriter(gz)
t.WriteHeader(&tar.Header{
Name: cfg.dbFilename(),
})
t.Write([]byte("1234"))
t.Close()
gz.Close()
w.WriteHeader(http.StatusOK)
}
func mockDoReq(cfg *MaxMind, t *testing.T) {
rw := httptest.NewRecorder()
oldDoReq := doReq
doReq = func(req *http.Request) (*http.Response, error) {
testDoReq(cfg, rw, req)
return rw.Result(), nil
}
t.Cleanup(func() { doReq = oldDoReq })
}
func mockDataDir(t *testing.T) {
oldDataDir := dataDir
dataDir = t.TempDir()
t.Cleanup(func() { dataDir = oldDataDir })
}
func mockMaxMindDBOpen(t *testing.T) {
oldMaxMindDBOpen := maxmindDBOpen
maxmindDBOpen = func(path string) (*maxminddb.Reader, error) {
return &maxminddb.Reader{}, nil
}
t.Cleanup(func() { maxmindDBOpen = oldMaxMindDBOpen })
}
func Test_MaxMindConfig_doReq(t *testing.T) {
cfg := testCfg()
mockDoReq(cfg, t)
resp, err := cfg.doReq(http.MethodGet)
if err != nil {
t.Fatalf("newReq() error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("unexpected status: %v", resp.StatusCode)
}
}
func Test_MaxMindConfig_checkLatest(t *testing.T) {
cfg := testCfg()
mockDoReq(cfg, t)
latest, err := cfg.checkLastest()
if err != nil {
t.Fatalf("checkLatest() error = %v", err)
}
if latest.Equal(testLastMod) {
t.Errorf("expected latest equal to testLastMod")
}
}
func Test_MaxMindConfig_download(t *testing.T) {
cfg := testCfg()
mockDataDir(t)
mockMaxMindDBOpen(t)
mockDoReq(cfg, t)
err := cfg.download()
if err != nil {
t.Fatalf("download() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}
func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
cfg := testCfg()
mockDataDir(t)
mockMaxMindDBOpen(t)
task := task.RootTask("test")
defer task.Finish(nil)
err := cfg.LoadMaxMindDB(task)
if err != nil {
t.Errorf("loadMaxMindDB() error = %v", err)
}
if cfg.db.Reader == nil {
t.Error("expected db instance")
}
}

View file

@ -1,4 +1,4 @@
package acl
package maxmind
type City struct {
Location struct {

View file

@ -0,0 +1,33 @@
package maxmind
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
)
type (
DatabaseType string
Config struct {
AccountID string `json:"account_id" validate:"required"`
LicenseKey string `json:"license_key" validate:"required"`
Database DatabaseType `json:"database" validate:"omitempty,oneof=geolite geoip2"`
}
)
const (
MaxMindGeoLite DatabaseType = "geolite"
MaxMindGeoIP2 DatabaseType = "geoip2"
)
func (cfg *Config) Validate() gperr.Error {
if cfg.Database == "" {
cfg.Database = MaxMindGeoLite
}
return nil
}
func (cfg *Config) Logger() *zerolog.Logger {
l := logging.With().Str("database", string(cfg.Database)).Logger()
return &l
}

View file

@ -1,4 +1,4 @@
package acl
package maxmind
import "net"