updated validation for middleware options

This commit is contained in:
yusing 2024-11-30 04:00:55 +08:00
parent edc1ad952d
commit 6e9b5cc113
9 changed files with 97 additions and 94 deletions

View file

@ -16,9 +16,9 @@ type cidrWhitelist struct {
}
type cidrWhitelistOpts struct {
Allow []*types.CIDR `json:"allow"`
StatusCode int `json:"statusCode"`
Message string `json:"message"`
Allow []*types.CIDR `validate:"min=1"`
StatusCode int `validate:"omitempty,gte=400,lte=599"`
Message string
}
var (
@ -42,9 +42,6 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
if err != nil {
return nil, err
}
if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.New("no allowed CIDRs")
}
return wl.m, nil
}

View file

@ -2,10 +2,12 @@ package middleware
import (
_ "embed"
"net"
"net/http"
"testing"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -13,6 +15,37 @@ import (
var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelistValidation(t *testing.T) {
t.Run("valid", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"message": "test-message",
})
ExpectNoError(t, err)
})
t.Run("missing allow", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/123"},
"message": "test-message",
})
ExpectErrorT[*net.ParseError](t, err)
})
t.Run("invalid status code", func(t *testing.T) {
_, err := NewCIDRWhitelist(OptionsRaw{
"allow": []string{"1.2.3.4/32"},
"status_code": 600,
"message": "test-message",
})
ExpectError(t, utils.ErrValidationError, err)
})
}
func TestCIDRWhitelist(t *testing.T) {
errs := E.NewBuilder("")
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
@ -24,6 +57,7 @@ func TestCIDRWhitelist(t *testing.T) {
}
t.Run("deny", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err)
@ -33,6 +67,7 @@ func TestCIDRWhitelist(t *testing.T) {
})
t.Run("accept", func(t *testing.T) {
t.Parallel()
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err)

View file

@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
@ -21,63 +20,48 @@ type (
forwardAuth struct {
forwardAuthOpts
m *Middleware
client http.Client
}
forwardAuthOpts struct {
Address string `json:"address"`
TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
// TODO: use tr from reverse proxy
tr, ok := fa.transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = gphttp.DefaultTransport.Clone()
}
fa.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
return fa.m, nil
}
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header)
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
url,
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
fa.m.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
fa.setAuthHeaders(req, faReq)
fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := fa.client.Do(faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
fa.m.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}

View file

@ -14,9 +14,9 @@ type (
}
// order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct {
SetHeaders map[string]string `json:"setHeaders"`
AddHeaders map[string]string `json:"addHeaders"`
HideHeaders []string `json:"hideHeaders"`
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
)

View file

@ -21,9 +21,9 @@ type (
}
rateLimiterOpts struct {
Average int `json:"average"`
Burst int `json:"burst"`
Period time.Duration `json:"period"`
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration
}
)

View file

@ -16,9 +16,9 @@ type realIP struct {
type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string `json:"header"`
Header string `validate:"required"`
// From is a list of Address / CIDRs to trust
From []*types.CIDR `json:"from"`
From []*types.CIDR `validate:"min=1"`
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@ -27,7 +27,7 @@ type realIPOpts struct {
the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field.
*/
Recursive bool `json:"recursive"`
Recursive bool
}
var (

View file

@ -20,3 +20,4 @@ accept:
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
- 127.0.0.1

View file

@ -1,45 +1,7 @@
package types
import (
"bytes"
"net"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type CIDR net.IPNet
var (
ErrInvalidCIDR = E.New("invalid CIDR")
ErrInvalidCIDRType = E.New("invalid CIDR type")
)
func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string)
if !ok {
return ErrInvalidCIDRType.Subjectf("%T", val)
}
if !strings.Contains(cidrStr, "/") {
cidrStr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil {
return ErrInvalidCIDR.Subject(cidrStr)
}
*cidr = CIDR(*ipnet)
return nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {
return (*net.IPNet)(cidr).Contains(ip)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}
func (cidr *CIDR) Equals(other *CIDR) bool {
return (*net.IPNet)(cidr).IP.Equal(other.IP) && bytes.Equal(cidr.Mask, other.Mask)
}
type CIDR = net.IPNet

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"net"
"reflect"
"strconv"
"strings"
@ -32,6 +33,7 @@ var (
ErrMapMissingColon = E.New("map missing colon")
ErrMapTooManyColons = E.New("map too many colons")
ErrUnknownField = E.New("unknown field")
ErrValidationError = E.New("validation error")
)
var validate = validator.New()
@ -203,11 +205,23 @@ func Deserialize(src SerializedObject, dst any) E.Error {
errs.Add(err.Subject(k))
}
} else {
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dstV.Interface()))))
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
}
}
if needValidate {
errs.Add(validate.Struct(dstV.Interface()))
err := validate.Struct(dstV.Interface())
var valErrs validator.ValidationErrors
if errors.As(err, &valErrs) {
for _, e := range valErrs {
detail := e.ActualTag()
if e.Param() != "" {
detail += ":" + e.Param()
}
errs.Add(ErrValidationError.
Subject(strutils.ToLowerNoSnake(e.Field())).
Withf("require %q", detail))
}
}
}
return errs.Error()
case reflect.Map:
@ -337,6 +351,16 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
}
dst.Set(reflect.ValueOf(d))
return
case reflect.TypeFor[net.IPNet]():
if !strings.Contains(src, "/") {
src += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(src)
if err != nil {
return true, E.From(err)
}
dst.Set(reflect.ValueOf(*ipnet))
return
default:
}
// primitive types / simple types