From 6e9b5cc11333f93cacc3ec7e6c55c02b9be3c1c7 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 30 Nov 2024 04:00:55 +0800 Subject: [PATCH] updated validation for middleware options --- .../net/http/middleware/cidr_whitelist.go | 9 +-- .../http/middleware/cidr_whitelist_test.go | 35 +++++++++++ internal/net/http/middleware/forward_auth.go | 60 +++++++------------ .../net/http/middleware/modify_request.go | 6 +- internal/net/http/middleware/rate_limit.go | 6 +- internal/net/http/middleware/real_ip.go | 6 +- .../test_data/cidr_whitelist_test.yml | 1 + internal/net/types/cidr.go | 40 +------------ internal/utils/serialization.go | 28 ++++++++- 9 files changed, 97 insertions(+), 94 deletions(-) diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index d28c633..1cf6199 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -16,9 +16,9 @@ type cidrWhitelist struct { } type cidrWhitelistOpts struct { - Allow []*types.CIDR `json:"allow"` - StatusCode int `json:"statusCode"` - Message string `json:"message"` + Allow []*types.CIDR `validate:"min=1"` + StatusCode int `validate:"omitempty,gte=400,lte=599"` + Message string } var ( @@ -42,9 +42,6 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { if err != nil { return nil, err } - if len(wl.cidrWhitelistOpts.Allow) == 0 { - return nil, E.New("no allowed CIDRs") - } return wl.m, nil } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index 3c278d9..c1d5d6c 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -2,10 +2,12 @@ package middleware import ( _ "embed" + "net" "net/http" "testing" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -13,6 +15,37 @@ import ( var testCIDRWhitelistCompose []byte var deny, accept *Middleware +func TestCIDRWhitelistValidation(t *testing.T) { + t.Run("valid", func(t *testing.T) { + _, err := NewCIDRWhitelist(OptionsRaw{ + "allow": []string{"1.2.3.4/32"}, + "message": "test-message", + }) + ExpectNoError(t, err) + }) + t.Run("missing allow", func(t *testing.T) { + _, err := NewCIDRWhitelist(OptionsRaw{ + "message": "test-message", + }) + ExpectError(t, utils.ErrValidationError, err) + }) + t.Run("invalid cidr", func(t *testing.T) { + _, err := NewCIDRWhitelist(OptionsRaw{ + "allow": []string{"1.2.3.4/123"}, + "message": "test-message", + }) + ExpectErrorT[*net.ParseError](t, err) + }) + t.Run("invalid status code", func(t *testing.T) { + _, err := NewCIDRWhitelist(OptionsRaw{ + "allow": []string{"1.2.3.4/32"}, + "status_code": 600, + "message": "test-message", + }) + ExpectError(t, utils.ErrValidationError, err) + }) +} + func TestCIDRWhitelist(t *testing.T) { errs := E.NewBuilder("") mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) @@ -24,6 +57,7 @@ func TestCIDRWhitelist(t *testing.T) { } t.Run("deny", func(t *testing.T) { + t.Parallel() for range 10 { result, err := newMiddlewareTest(deny, nil) ExpectNoError(t, err) @@ -33,6 +67,7 @@ func TestCIDRWhitelist(t *testing.T) { }) t.Run("accept", func(t *testing.T) { + t.Parallel() for range 10 { result, err := newMiddlewareTest(accept, nil) ExpectNoError(t, err) diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index dcaf1d6..7892c77 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "net/url" "slices" "strings" "time" @@ -20,64 +19,49 @@ import ( type ( forwardAuth struct { forwardAuthOpts - m *Middleware - client http.Client + m *Middleware } forwardAuthOpts struct { - Address string `json:"address"` - TrustForwardHeader bool `json:"trustForwardHeader"` - AuthResponseHeaders []string `json:"authResponseHeaders"` - AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"` - - transport http.RoundTripper + Address string `validate:"url,required"` + TrustForwardHeader bool + AuthResponseHeaders []string + AddAuthCookiesToResponse []string } ) -var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc} +var ForwardAuth = &Middleware{withOptions: NewForwardAuth} -func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { +var faHTTPClient = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(r *Request, via []*Request) error { + return http.ErrUseLastResponse + }, +} + +func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) { fa := new(forwardAuth) if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil { return nil, err } - if _, err := url.Parse(fa.Address); err != nil { - return nil, E.From(err) - } - fa.m = &Middleware{ impl: fa, before: fa.forward, } - - // TODO: use tr from reverse proxy - tr, ok := fa.transport.(*http.Transport) - if ok { - tr = tr.Clone() - } else { - tr = gphttp.DefaultTransport.Clone() - } - - fa.client = http.Client{ - CheckRedirect: func(r *Request, via []*Request) error { - return http.ErrUseLastResponse - }, - Timeout: 30 * time.Second, - Transport: tr, - } return fa.m, nil } func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { gphttp.RemoveHop(req.Header) + url := fa.Address faReq, err := http.NewRequestWithContext( req.Context(), http.MethodGet, - fa.Address, + url, nil, ) if err != nil { - fa.m.AddTracef("new request err to %s", fa.Address).WithError(err) + fa.m.AddTracef("new request err to %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req fa.setAuthHeaders(req, faReq) fa.m.AddTraceRequest("forward auth request", faReq) - faResp, err := fa.client.Do(faReq) + faResp, err := faHTTPClient.Do(faReq) if err != nil { - fa.m.AddTracef("failed to call %s", fa.Address).WithError(err) + fa.m.AddTracef("failed to call %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req body, err := io.ReadAll(faResp.Body) if err != nil { - fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err) + fa.m.AddTracef("failed to read response body from %s", url).WithError(err) w.WriteHeader(http.StatusInternalServerError) return } @@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req redirectURL, err := faResp.Location() if err != nil { - fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp) + fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp) w.WriteHeader(http.StatusInternalServerError) return } else if redirectURL.String() != "" { @@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req w.WriteHeader(faResp.StatusCode) if _, err = w.Write(body); err != nil { - fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp) + fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp) } return } diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 758b62e..89b0622 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -14,9 +14,9 @@ type ( } // order: set_headers -> add_headers -> hide_headers modifyRequestOpts struct { - SetHeaders map[string]string `json:"setHeaders"` - AddHeaders map[string]string `json:"addHeaders"` - HideHeaders []string `json:"hideHeaders"` + SetHeaders map[string]string + AddHeaders map[string]string + HideHeaders []string } ) diff --git a/internal/net/http/middleware/rate_limit.go b/internal/net/http/middleware/rate_limit.go index 8e118fd..eacc770 100644 --- a/internal/net/http/middleware/rate_limit.go +++ b/internal/net/http/middleware/rate_limit.go @@ -21,9 +21,9 @@ type ( } rateLimiterOpts struct { - Average int `json:"average"` - Burst int `json:"burst"` - Period time.Duration `json:"period"` + Average int `validate:"min=1,required"` + Burst int `validate:"min=1,required"` + Period time.Duration } ) diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 40a2beb..bdf0879 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -16,9 +16,9 @@ type realIP struct { type realIPOpts struct { // Header is the name of the header to use for the real client IP - Header string `json:"header"` + Header string `validate:"required"` // From is a list of Address / CIDRs to trust - From []*types.CIDR `json:"from"` + From []*types.CIDR `validate:"min=1"` /* If recursive search is disabled, the original client address that matches one of the trusted addresses is replaced by @@ -27,7 +27,7 @@ type realIPOpts struct { the original client address that matches one of the trusted addresses is replaced by the last non-trusted address sent in the request header field. */ - Recursive bool `json:"recursive"` + Recursive bool } var ( diff --git a/internal/net/http/middleware/test_data/cidr_whitelist_test.yml b/internal/net/http/middleware/test_data/cidr_whitelist_test.yml index 4c414dd..1cd5411 100644 --- a/internal/net/http/middleware/test_data/cidr_whitelist_test.yml +++ b/internal/net/http/middleware/test_data/cidr_whitelist_test.yml @@ -20,3 +20,4 @@ accept: - use: CIDRWhitelist allow: - 192.168.0.0/24 + - 127.0.0.1 diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 0ee3c98..09638b4 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -1,45 +1,7 @@ package types import ( - "bytes" "net" - "strings" - - E "github.com/yusing/go-proxy/internal/error" ) -type CIDR net.IPNet - -var ( - ErrInvalidCIDR = E.New("invalid CIDR") - ErrInvalidCIDRType = E.New("invalid CIDR type") -) - -func (cidr *CIDR) ConvertFrom(val any) E.Error { - cidrStr, ok := val.(string) - if !ok { - return ErrInvalidCIDRType.Subjectf("%T", val) - } - - if !strings.Contains(cidrStr, "/") { - cidrStr += "/32" // single IP - } - _, ipnet, err := net.ParseCIDR(cidrStr) - if err != nil { - return ErrInvalidCIDR.Subject(cidrStr) - } - *cidr = CIDR(*ipnet) - return nil -} - -func (cidr *CIDR) Contains(ip net.IP) bool { - return (*net.IPNet)(cidr).Contains(ip) -} - -func (cidr *CIDR) String() string { - return (*net.IPNet)(cidr).String() -} - -func (cidr *CIDR) Equals(other *CIDR) bool { - return (*net.IPNet)(cidr).IP.Equal(other.IP) && bytes.Equal(cidr.Mask, other.Mask) -} +type CIDR = net.IPNet diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 2bc3368..f7cb167 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "net" "reflect" "strconv" "strings" @@ -32,6 +33,7 @@ var ( ErrMapMissingColon = E.New("map missing colon") ErrMapTooManyColons = E.New("map too many colons") ErrUnknownField = E.New("unknown field") + ErrValidationError = E.New("validation error") ) var validate = validator.New() @@ -203,11 +205,23 @@ func Deserialize(src SerializedObject, dst any) E.Error { errs.Add(err.Subject(k)) } } else { - errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dstV.Interface())))) + errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping)))) } } if needValidate { - errs.Add(validate.Struct(dstV.Interface())) + err := validate.Struct(dstV.Interface()) + var valErrs validator.ValidationErrors + if errors.As(err, &valErrs) { + for _, e := range valErrs { + detail := e.ActualTag() + if e.Param() != "" { + detail += ":" + e.Param() + } + errs.Add(ErrValidationError. + Subject(strutils.ToLowerNoSnake(e.Field())). + Withf("require %q", detail)) + } + } } return errs.Error() case reflect.Map: @@ -337,6 +351,16 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E } dst.Set(reflect.ValueOf(d)) return + case reflect.TypeFor[net.IPNet](): + if !strings.Contains(src, "/") { + src += "/32" // single IP + } + _, ipnet, err := net.ParseCIDR(src) + if err != nil { + return true, E.From(err) + } + dst.Set(reflect.ValueOf(*ipnet)) + return default: } // primitive types / simple types