mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 12:42:34 +02:00
updated validation for middleware options
This commit is contained in:
parent
edc1ad952d
commit
6e9b5cc113
9 changed files with 97 additions and 94 deletions
|
@ -16,9 +16,9 @@ type cidrWhitelist struct {
|
|||
}
|
||||
|
||||
type cidrWhitelistOpts struct {
|
||||
Allow []*types.CIDR `json:"allow"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
Message string `json:"message"`
|
||||
Allow []*types.CIDR `validate:"min=1"`
|
||||
StatusCode int `validate:"omitempty,gte=400,lte=599"`
|
||||
Message string
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -42,9 +42,6 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(wl.cidrWhitelistOpts.Allow) == 0 {
|
||||
return nil, E.New("no allowed CIDRs")
|
||||
}
|
||||
return wl.m, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,10 +2,12 @@ package middleware
|
|||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -13,6 +15,37 @@ import (
|
|||
var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/32"},
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
t.Run("missing allow", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid cidr", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/123"},
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectErrorT[*net.ParseError](t, err)
|
||||
})
|
||||
t.Run("invalid status code", func(t *testing.T) {
|
||||
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||
"allow": []string{"1.2.3.4/32"},
|
||||
"status_code": 600,
|
||||
"message": "test-message",
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
errs := E.NewBuilder("")
|
||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||
|
@ -24,6 +57,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err)
|
||||
|
@ -33,6 +67,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err)
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -20,64 +19,49 @@ import (
|
|||
type (
|
||||
forwardAuth struct {
|
||||
forwardAuthOpts
|
||||
m *Middleware
|
||||
client http.Client
|
||||
m *Middleware
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
Address string `json:"address"`
|
||||
TrustForwardHeader bool `json:"trustForwardHeader"`
|
||||
AuthResponseHeaders []string `json:"authResponseHeaders"`
|
||||
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
|
||||
|
||||
transport http.RoundTripper
|
||||
Address string `validate:"url,required"`
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
AddAuthCookiesToResponse []string
|
||||
}
|
||||
)
|
||||
|
||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
|
||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
|
||||
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
var faHTTPClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: func(r *Request, via []*Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := url.Parse(fa.Address); err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
||||
fa.m = &Middleware{
|
||||
impl: fa,
|
||||
before: fa.forward,
|
||||
}
|
||||
|
||||
// TODO: use tr from reverse proxy
|
||||
tr, ok := fa.transport.(*http.Transport)
|
||||
if ok {
|
||||
tr = tr.Clone()
|
||||
} else {
|
||||
tr = gphttp.DefaultTransport.Clone()
|
||||
}
|
||||
|
||||
fa.client = http.Client{
|
||||
CheckRedirect: func(r *Request, via []*Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: tr,
|
||||
}
|
||||
return fa.m, nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
gphttp.RemoveHop(req.Header)
|
||||
|
||||
url := fa.Address
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
req.Context(),
|
||||
http.MethodGet,
|
||||
fa.Address,
|
||||
url,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
|
||||
fa.m.AddTracef("new request err to %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
fa.setAuthHeaders(req, faReq)
|
||||
fa.m.AddTraceRequest("forward auth request", faReq)
|
||||
|
||||
faResp, err := fa.client.Do(faReq)
|
||||
faResp, err := faHTTPClient.Do(faReq)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
|
||||
fa.m.AddTracef("failed to call %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
|
||||
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
|
||||
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
|
@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
|
||||
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -14,9 +14,9 @@ type (
|
|||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
SetHeaders map[string]string `json:"setHeaders"`
|
||||
AddHeaders map[string]string `json:"addHeaders"`
|
||||
HideHeaders []string `json:"hideHeaders"`
|
||||
SetHeaders map[string]string
|
||||
AddHeaders map[string]string
|
||||
HideHeaders []string
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -21,9 +21,9 @@ type (
|
|||
}
|
||||
|
||||
rateLimiterOpts struct {
|
||||
Average int `json:"average"`
|
||||
Burst int `json:"burst"`
|
||||
Period time.Duration `json:"period"`
|
||||
Average int `validate:"min=1,required"`
|
||||
Burst int `validate:"min=1,required"`
|
||||
Period time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -16,9 +16,9 @@ type realIP struct {
|
|||
|
||||
type realIPOpts struct {
|
||||
// Header is the name of the header to use for the real client IP
|
||||
Header string `json:"header"`
|
||||
Header string `validate:"required"`
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*types.CIDR `json:"from"`
|
||||
From []*types.CIDR `validate:"min=1"`
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
|
@ -27,7 +27,7 @@ type realIPOpts struct {
|
|||
the original client address that matches one of the trusted addresses is replaced by
|
||||
the last non-trusted address sent in the request header field.
|
||||
*/
|
||||
Recursive bool `json:"recursive"`
|
||||
Recursive bool
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -20,3 +20,4 @@ accept:
|
|||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
- 127.0.0.1
|
||||
|
|
|
@ -1,45 +1,7 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type CIDR net.IPNet
|
||||
|
||||
var (
|
||||
ErrInvalidCIDR = E.New("invalid CIDR")
|
||||
ErrInvalidCIDRType = E.New("invalid CIDR type")
|
||||
)
|
||||
|
||||
func (cidr *CIDR) ConvertFrom(val any) E.Error {
|
||||
cidrStr, ok := val.(string)
|
||||
if !ok {
|
||||
return ErrInvalidCIDRType.Subjectf("%T", val)
|
||||
}
|
||||
|
||||
if !strings.Contains(cidrStr, "/") {
|
||||
cidrStr += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidrStr)
|
||||
if err != nil {
|
||||
return ErrInvalidCIDR.Subject(cidrStr)
|
||||
}
|
||||
*cidr = CIDR(*ipnet)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Contains(ip net.IP) bool {
|
||||
return (*net.IPNet)(cidr).Contains(ip)
|
||||
}
|
||||
|
||||
func (cidr *CIDR) String() string {
|
||||
return (*net.IPNet)(cidr).String()
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Equals(other *CIDR) bool {
|
||||
return (*net.IPNet)(cidr).IP.Equal(other.IP) && bytes.Equal(cidr.Mask, other.Mask)
|
||||
}
|
||||
type CIDR = net.IPNet
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -32,6 +33,7 @@ var (
|
|||
ErrMapMissingColon = E.New("map missing colon")
|
||||
ErrMapTooManyColons = E.New("map too many colons")
|
||||
ErrUnknownField = E.New("unknown field")
|
||||
ErrValidationError = E.New("validation error")
|
||||
)
|
||||
|
||||
var validate = validator.New()
|
||||
|
@ -203,11 +205,23 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
|||
errs.Add(err.Subject(k))
|
||||
}
|
||||
} else {
|
||||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dstV.Interface()))))
|
||||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
||||
}
|
||||
}
|
||||
if needValidate {
|
||||
errs.Add(validate.Struct(dstV.Interface()))
|
||||
err := validate.Struct(dstV.Interface())
|
||||
var valErrs validator.ValidationErrors
|
||||
if errors.As(err, &valErrs) {
|
||||
for _, e := range valErrs {
|
||||
detail := e.ActualTag()
|
||||
if e.Param() != "" {
|
||||
detail += ":" + e.Param()
|
||||
}
|
||||
errs.Add(ErrValidationError.
|
||||
Subject(strutils.ToLowerNoSnake(e.Field())).
|
||||
Withf("require %q", detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
case reflect.Map:
|
||||
|
@ -337,6 +351,16 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
|||
}
|
||||
dst.Set(reflect.ValueOf(d))
|
||||
return
|
||||
case reflect.TypeFor[net.IPNet]():
|
||||
if !strings.Contains(src, "/") {
|
||||
src += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(src)
|
||||
if err != nil {
|
||||
return true, E.From(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(*ipnet))
|
||||
return
|
||||
default:
|
||||
}
|
||||
// primitive types / simple types
|
||||
|
|
Loading…
Add table
Reference in a new issue