updated validation for middleware options

This commit is contained in:
yusing 2024-11-30 04:00:55 +08:00
parent edc1ad952d
commit 6e9b5cc113
9 changed files with 97 additions and 94 deletions

View file

@ -16,9 +16,9 @@ type cidrWhitelist struct {
} }
type cidrWhitelistOpts struct { type cidrWhitelistOpts struct {
Allow []*types.CIDR `json:"allow"` Allow []*types.CIDR `validate:"min=1"`
StatusCode int `json:"statusCode"` StatusCode int `validate:"omitempty,gte=400,lte=599"`
Message string `json:"message"` Message string
} }
var ( var (
@ -42,9 +42,6 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.New("no allowed CIDRs")
}
return wl.m, nil return wl.m, nil
} }

View file

@ -2,10 +2,12 @@ package middleware
import ( import (
_ "embed" _ "embed"
"net"
"net/http" "net/http"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -13,6 +15,37 @@ import (
var testCIDRWhitelistCompose []byte var testCIDRWhitelistCompose []byte
var deny, accept *Middleware var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
t.Run("valid", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"message": "test-message",
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/123"},
"message": "test-message",
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"status_code": 600,
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestCIDRWhitelist(t *testing.T) { func TestCIDRWhitelist(t *testing.T) {
errs := E.NewBuilder("") errs := E.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
@ -24,6 +57,7 @@ func TestCIDRWhitelist(t *testing.T) {
} }
t.Run("deny", func(t *testing.T) { t.Run("deny", func(t *testing.T) {
t.Parallel()
for range 10 { for range 10 {
result, err := newMiddlewareTest(deny, nil) result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err) ExpectNoError(t, err)
@ -33,6 +67,7 @@ func TestCIDRWhitelist(t *testing.T) {
}) })
t.Run("accept", func(t *testing.T) { t.Run("accept", func(t *testing.T) {
t.Parallel()
for range 10 { for range 10 {
result, err := newMiddlewareTest(accept, nil) result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err) ExpectNoError(t, err)

View file

@ -8,7 +8,6 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url"
"slices" "slices"
"strings" "strings"
"time" "time"
@ -21,63 +20,48 @@ type (
forwardAuth struct { forwardAuth struct {
forwardAuthOpts forwardAuthOpts
m *Middleware m *Middleware
client http.Client
} }
forwardAuthOpts struct { forwardAuthOpts struct {
Address string `json:"address"` Address string `validate:"url,required"`
TrustForwardHeader bool `json:"trustForwardHeader"` TrustForwardHeader bool
AuthResponseHeaders []string `json:"authResponseHeaders"` AuthResponseHeaders []string
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"` AddAuthCookiesToResponse []string
transport http.RoundTripper
} }
) )
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc} var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth) fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil { if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err return nil, err
} }
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{ fa.m = &Middleware{
impl: fa, impl: fa,
before: fa.forward, before: fa.forward,
} }
// TODO: use tr from reverse proxy
tr, ok := fa.transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = gphttp.DefaultTransport.Clone()
}
fa.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
return fa.m, nil return fa.m, nil
} }
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header) gphttp.RemoveHop(req.Header)
url := fa.Address
faReq, err := http.NewRequestWithContext( faReq, err := http.NewRequestWithContext(
req.Context(), req.Context(),
http.MethodGet, http.MethodGet,
fa.Address, url,
nil, nil,
) )
if err != nil { if err != nil {
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err) fa.m.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
fa.setAuthHeaders(req, faReq) fa.setAuthHeaders(req, faReq)
fa.m.AddTraceRequest("forward auth request", faReq) fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := fa.client.Do(faReq) faResp, err := faHTTPClient.Do(faReq)
if err != nil { if err != nil {
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err) fa.m.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body) body, err := io.ReadAll(faResp.Body)
if err != nil { if err != nil {
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err) fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
redirectURL, err := faResp.Location() redirectURL, err := faResp.Location()
if err != nil { if err != nil {
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp) fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} else if redirectURL.String() != "" { } else if redirectURL.String() != "" {
@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
w.WriteHeader(faResp.StatusCode) w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil { if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp) fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
} }
return return
} }

View file

@ -14,9 +14,9 @@ type (
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct { modifyRequestOpts struct {
SetHeaders map[string]string `json:"setHeaders"` SetHeaders map[string]string
AddHeaders map[string]string `json:"addHeaders"` AddHeaders map[string]string
HideHeaders []string `json:"hideHeaders"` HideHeaders []string
} }
) )

View file

@ -21,9 +21,9 @@ type (
} }
rateLimiterOpts struct { rateLimiterOpts struct {
Average int `json:"average"` Average int `validate:"min=1,required"`
Burst int `json:"burst"` Burst int `validate:"min=1,required"`
Period time.Duration `json:"period"` Period time.Duration
} }
) )

View file

@ -16,9 +16,9 @@ type realIP struct {
type realIPOpts struct { type realIPOpts struct {
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP
Header string `json:"header"` Header string `validate:"required"`
// From is a list of Address / CIDRs to trust // From is a list of Address / CIDRs to trust
From []*types.CIDR `json:"from"` From []*types.CIDR `validate:"min=1"`
/* /*
If recursive search is disabled, If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
@ -27,7 +27,7 @@ type realIPOpts struct {
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field. the last non-trusted address sent in the request header field.
*/ */
Recursive bool `json:"recursive"` Recursive bool
} }
var ( var (

View file

@ -20,3 +20,4 @@ accept:
- use: CIDRWhitelist - use: CIDRWhitelist
allow: allow:
- 192.168.0.0/24 - 192.168.0.0/24
- 127.0.0.1

View file

@ -1,45 +1,7 @@
package types package types
import ( import (
"bytes"
"net" "net"
"strings"
E "github.com/yusing/go-proxy/internal/error"
) )
type CIDR net.IPNet type CIDR = net.IPNet
var (
ErrInvalidCIDR = E.New("invalid CIDR")
ErrInvalidCIDRType = E.New("invalid CIDR type")
)
func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string)
if !ok {
return ErrInvalidCIDRType.Subjectf("%T", val)
}
if !strings.Contains(cidrStr, "/") {
cidrStr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil {
return ErrInvalidCIDR.Subject(cidrStr)
}
*cidr = CIDR(*ipnet)
return nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {
return (*net.IPNet)(cidr).Contains(ip)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}
func (cidr *CIDR) Equals(other *CIDR) bool {
return (*net.IPNet)(cidr).IP.Equal(other.IP) && bytes.Equal(cidr.Mask, other.Mask)
}

View file

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"net"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -32,6 +33,7 @@ var (
ErrMapMissingColon = E.New("map missing colon") ErrMapMissingColon = E.New("map missing colon")
ErrMapTooManyColons = E.New("map too many colons") ErrMapTooManyColons = E.New("map too many colons")
ErrUnknownField = E.New("unknown field") ErrUnknownField = E.New("unknown field")
ErrValidationError = E.New("validation error")
) )
var validate = validator.New() var validate = validator.New()
@ -203,11 +205,23 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(err.Subject(k)) errs.Add(err.Subject(k))
} }
} else { } else {
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dstV.Interface())))) errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
} }
} }
if needValidate { if needValidate {
errs.Add(validate.Struct(dstV.Interface())) err := validate.Struct(dstV.Interface())
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(strutils.ToLowerNoSnake(e.Field())).
Withf("require %q", detail))
}
}
} }
return errs.Error() return errs.Error()
case reflect.Map: case reflect.Map:
@ -337,6 +351,16 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
} }
dst.Set(reflect.ValueOf(d)) dst.Set(reflect.ValueOf(d))
return return
case reflect.TypeFor[net.IPNet]():
if !strings.Contains(src, "/") {
src += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(src)
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(*ipnet))
return
default: default:
} }
// primitive types / simple types // primitive types / simple types