diff --git a/internal/acl/config.go b/internal/acl/config.go index f058013..90fe152 100644 --- a/internal/acl/config.go +++ b/internal/acl/config.go @@ -133,19 +133,15 @@ func (c *Config) IPAllowed(ip net.IP) bool { } ipAndStr := &maxmind.IPInfo{IP: ip, Str: ipStr} - for _, m := range c.Allow { - if m(ipAndStr) { - c.log(ipAndStr, true) - c.cacheRecord(ipAndStr, true) - return true - } + if c.Allow.Match(ipAndStr) { + c.log(ipAndStr, true) + c.cacheRecord(ipAndStr, true) + return true } - for _, m := range c.Deny { - if m(ipAndStr) { - c.log(ipAndStr, false) - c.cacheRecord(ipAndStr, false) - return false - } + if c.Deny.Match(ipAndStr) { + c.log(ipAndStr, false) + c.cacheRecord(ipAndStr, false) + return false } c.log(ipAndStr, c.defaultAllow) diff --git a/internal/acl/matcher.go b/internal/acl/matcher.go index 34984ce..a9d99b4 100644 --- a/internal/acl/matcher.go +++ b/internal/acl/matcher.go @@ -1,6 +1,7 @@ package acl import ( + "errors" "net" "strings" @@ -8,7 +9,11 @@ import ( "github.com/yusing/go-proxy/internal/maxmind" ) -type Matcher func(*maxmind.IPInfo) bool +type MatcherFunc func(*maxmind.IPInfo) bool + +type Matcher struct { + match MatcherFunc +} type Matchers []Matcher @@ -29,68 +34,69 @@ var errMatcherFormat = gperr.Multiline().AddLines( ) var ( - errSyntax = gperr.New("syntax error") - errInvalidIP = gperr.New("invalid IP") - errInvalidCIDR = gperr.New("invalid CIDR") - errMaxMindNotConfigured = gperr.New("MaxMind not configured") + errSyntax = errors.New("syntax error") + errInvalidIP = errors.New("invalid IP") + errInvalidCIDR = errors.New("invalid CIDR") + errMaxMindNotConfigured = errors.New("MaxMind not configured") ) -func ParseMatcher(s string) (Matcher, gperr.Error) { +func (matcher *Matcher) Parse(s string) error { parts := strings.Split(s, ":") if len(parts) != 2 { - return nil, errSyntax + return errSyntax } switch parts[0] { case MatcherTypeIP: ip := net.ParseIP(parts[1]) if ip == nil { - return nil, errInvalidIP + return errInvalidIP } - return matchIP(ip), nil + matcher.match = matchIP(ip) case MatcherTypeCIDR: _, net, err := net.ParseCIDR(parts[1]) if err != nil { - return nil, errInvalidCIDR + return errInvalidCIDR } - return matchCIDR(net), nil + matcher.match = matchCIDR(net) case MatcherTypeTimeZone: if !maxmind.HasInstance() { - return nil, errMaxMindNotConfigured + return errMaxMindNotConfigured } - return matchTimeZone(parts[1]), nil + matcher.match = matchTimeZone(parts[1]) case MatcherTypeCountry: if !maxmind.HasInstance() { - return nil, errMaxMindNotConfigured + return errMaxMindNotConfigured } - return matchISOCode(parts[1]), nil + matcher.match = matchISOCode(parts[1]) default: - return nil, errSyntax + return errSyntax } + return nil } func (matchers Matchers) Match(ip *maxmind.IPInfo) bool { for _, m := range matchers { - if m(ip) { + if m.match(ip) { return true } } return false } -func matchIP(ip net.IP) Matcher { +func matchIP(ip net.IP) MatcherFunc { return func(ip2 *maxmind.IPInfo) bool { return ip.Equal(ip2.IP) } } -func matchCIDR(n *net.IPNet) Matcher { +func matchCIDR(n *net.IPNet) MatcherFunc { return func(ip *maxmind.IPInfo) bool { return n.Contains(ip.IP) } } -func matchTimeZone(tz string) Matcher { +func matchTimeZone(tz string) MatcherFunc { return func(ip *maxmind.IPInfo) bool { city, ok := maxmind.LookupCity(ip) if !ok { @@ -100,7 +106,7 @@ func matchTimeZone(tz string) Matcher { } } -func matchISOCode(iso string) Matcher { +func matchISOCode(iso string) MatcherFunc { return func(ip *maxmind.IPInfo) bool { city, ok := maxmind.LookupCity(ip) if !ok {