mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 12:42:34 +02:00
updated validation for middleware options
This commit is contained in:
parent
edc1ad952d
commit
6e9b5cc113
9 changed files with 97 additions and 94 deletions
|
@ -16,9 +16,9 @@ type cidrWhitelist struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type cidrWhitelistOpts struct {
|
type cidrWhitelistOpts struct {
|
||||||
Allow []*types.CIDR `json:"allow"`
|
Allow []*types.CIDR `validate:"min=1"`
|
||||||
StatusCode int `json:"statusCode"`
|
StatusCode int `validate:"omitempty,gte=400,lte=599"`
|
||||||
Message string `json:"message"`
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -42,9 +42,6 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(wl.cidrWhitelistOpts.Allow) == 0 {
|
|
||||||
return nil, E.New("no allowed CIDRs")
|
|
||||||
}
|
|
||||||
return wl.m, nil
|
return wl.m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,12 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +15,37 @@ import (
|
||||||
var testCIDRWhitelistCompose []byte
|
var testCIDRWhitelistCompose []byte
|
||||||
var deny, accept *Middleware
|
var deny, accept *Middleware
|
||||||
|
|
||||||
|
func TestCIDRWhitelistValidation(t *testing.T) {
|
||||||
|
t.Run("valid", func(t *testing.T) {
|
||||||
|
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||||
|
"allow": []string{"1.2.3.4/32"},
|
||||||
|
"message": "test-message",
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
})
|
||||||
|
t.Run("missing allow", func(t *testing.T) {
|
||||||
|
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||||
|
"message": "test-message",
|
||||||
|
})
|
||||||
|
ExpectError(t, utils.ErrValidationError, err)
|
||||||
|
})
|
||||||
|
t.Run("invalid cidr", func(t *testing.T) {
|
||||||
|
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||||
|
"allow": []string{"1.2.3.4/123"},
|
||||||
|
"message": "test-message",
|
||||||
|
})
|
||||||
|
ExpectErrorT[*net.ParseError](t, err)
|
||||||
|
})
|
||||||
|
t.Run("invalid status code", func(t *testing.T) {
|
||||||
|
_, err := NewCIDRWhitelist(OptionsRaw{
|
||||||
|
"allow": []string{"1.2.3.4/32"},
|
||||||
|
"status_code": 600,
|
||||||
|
"message": "test-message",
|
||||||
|
})
|
||||||
|
ExpectError(t, utils.ErrValidationError, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestCIDRWhitelist(t *testing.T) {
|
func TestCIDRWhitelist(t *testing.T) {
|
||||||
errs := E.NewBuilder("")
|
errs := E.NewBuilder("")
|
||||||
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
|
||||||
|
@ -24,6 +57,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("deny", func(t *testing.T) {
|
t.Run("deny", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
for range 10 {
|
for range 10 {
|
||||||
result, err := newMiddlewareTest(deny, nil)
|
result, err := newMiddlewareTest(deny, nil)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
@ -33,6 +67,7 @@ func TestCIDRWhitelist(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("accept", func(t *testing.T) {
|
t.Run("accept", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
for range 10 {
|
for range 10 {
|
||||||
result, err := newMiddlewareTest(accept, nil)
|
result, err := newMiddlewareTest(accept, nil)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -20,64 +19,49 @@ import (
|
||||||
type (
|
type (
|
||||||
forwardAuth struct {
|
forwardAuth struct {
|
||||||
forwardAuthOpts
|
forwardAuthOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
client http.Client
|
|
||||||
}
|
}
|
||||||
forwardAuthOpts struct {
|
forwardAuthOpts struct {
|
||||||
Address string `json:"address"`
|
Address string `validate:"url,required"`
|
||||||
TrustForwardHeader bool `json:"trustForwardHeader"`
|
TrustForwardHeader bool
|
||||||
AuthResponseHeaders []string `json:"authResponseHeaders"`
|
AuthResponseHeaders []string
|
||||||
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
|
AddAuthCookiesToResponse []string
|
||||||
|
|
||||||
transport http.RoundTripper
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
|
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
|
||||||
|
|
||||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
var faHTTPClient = &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
CheckRedirect: func(r *Request, via []*Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
fa := new(forwardAuth)
|
fa := new(forwardAuth)
|
||||||
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := url.Parse(fa.Address); err != nil {
|
|
||||||
return nil, E.From(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fa.m = &Middleware{
|
fa.m = &Middleware{
|
||||||
impl: fa,
|
impl: fa,
|
||||||
before: fa.forward,
|
before: fa.forward,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: use tr from reverse proxy
|
|
||||||
tr, ok := fa.transport.(*http.Transport)
|
|
||||||
if ok {
|
|
||||||
tr = tr.Clone()
|
|
||||||
} else {
|
|
||||||
tr = gphttp.DefaultTransport.Clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
fa.client = http.Client{
|
|
||||||
CheckRedirect: func(r *Request, via []*Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
Transport: tr,
|
|
||||||
}
|
|
||||||
return fa.m, nil
|
return fa.m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||||
gphttp.RemoveHop(req.Header)
|
gphttp.RemoveHop(req.Header)
|
||||||
|
|
||||||
|
url := fa.Address
|
||||||
faReq, err := http.NewRequestWithContext(
|
faReq, err := http.NewRequestWithContext(
|
||||||
req.Context(),
|
req.Context(),
|
||||||
http.MethodGet,
|
http.MethodGet,
|
||||||
fa.Address,
|
url,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
|
fa.m.AddTracef("new request err to %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
fa.setAuthHeaders(req, faReq)
|
fa.setAuthHeaders(req, faReq)
|
||||||
fa.m.AddTraceRequest("forward auth request", faReq)
|
fa.m.AddTraceRequest("forward auth request", faReq)
|
||||||
|
|
||||||
faResp, err := fa.client.Do(faReq)
|
faResp, err := faHTTPClient.Do(faReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
|
fa.m.AddTracef("failed to call %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
|
|
||||||
body, err := io.ReadAll(faResp.Body)
|
body, err := io.ReadAll(faResp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
|
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
|
|
||||||
redirectURL, err := faResp.Location()
|
redirectURL, err := faResp.Location()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
|
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
} else if redirectURL.String() != "" {
|
} else if redirectURL.String() != "" {
|
||||||
|
@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
w.WriteHeader(faResp.StatusCode)
|
w.WriteHeader(faResp.StatusCode)
|
||||||
|
|
||||||
if _, err = w.Write(body); err != nil {
|
if _, err = w.Write(body); err != nil {
|
||||||
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
|
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,9 +14,9 @@ type (
|
||||||
}
|
}
|
||||||
// order: set_headers -> add_headers -> hide_headers
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
modifyRequestOpts struct {
|
modifyRequestOpts struct {
|
||||||
SetHeaders map[string]string `json:"setHeaders"`
|
SetHeaders map[string]string
|
||||||
AddHeaders map[string]string `json:"addHeaders"`
|
AddHeaders map[string]string
|
||||||
HideHeaders []string `json:"hideHeaders"`
|
HideHeaders []string
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -21,9 +21,9 @@ type (
|
||||||
}
|
}
|
||||||
|
|
||||||
rateLimiterOpts struct {
|
rateLimiterOpts struct {
|
||||||
Average int `json:"average"`
|
Average int `validate:"min=1,required"`
|
||||||
Burst int `json:"burst"`
|
Burst int `validate:"min=1,required"`
|
||||||
Period time.Duration `json:"period"`
|
Period time.Duration
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,9 @@ type realIP struct {
|
||||||
|
|
||||||
type realIPOpts struct {
|
type realIPOpts struct {
|
||||||
// Header is the name of the header to use for the real client IP
|
// Header is the name of the header to use for the real client IP
|
||||||
Header string `json:"header"`
|
Header string `validate:"required"`
|
||||||
// From is a list of Address / CIDRs to trust
|
// From is a list of Address / CIDRs to trust
|
||||||
From []*types.CIDR `json:"from"`
|
From []*types.CIDR `validate:"min=1"`
|
||||||
/*
|
/*
|
||||||
If recursive search is disabled,
|
If recursive search is disabled,
|
||||||
the original client address that matches one of the trusted addresses is replaced by
|
the original client address that matches one of the trusted addresses is replaced by
|
||||||
|
@ -27,7 +27,7 @@ type realIPOpts struct {
|
||||||
the original client address that matches one of the trusted addresses is replaced by
|
the original client address that matches one of the trusted addresses is replaced by
|
||||||
the last non-trusted address sent in the request header field.
|
the last non-trusted address sent in the request header field.
|
||||||
*/
|
*/
|
||||||
Recursive bool `json:"recursive"`
|
Recursive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -20,3 +20,4 @@ accept:
|
||||||
- use: CIDRWhitelist
|
- use: CIDRWhitelist
|
||||||
allow:
|
allow:
|
||||||
- 192.168.0.0/24
|
- 192.168.0.0/24
|
||||||
|
- 127.0.0.1
|
||||||
|
|
|
@ -1,45 +1,7 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CIDR net.IPNet
|
type CIDR = net.IPNet
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidCIDR = E.New("invalid CIDR")
|
|
||||||
ErrInvalidCIDRType = E.New("invalid CIDR type")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (cidr *CIDR) ConvertFrom(val any) E.Error {
|
|
||||||
cidrStr, ok := val.(string)
|
|
||||||
if !ok {
|
|
||||||
return ErrInvalidCIDRType.Subjectf("%T", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(cidrStr, "/") {
|
|
||||||
cidrStr += "/32" // single IP
|
|
||||||
}
|
|
||||||
_, ipnet, err := net.ParseCIDR(cidrStr)
|
|
||||||
if err != nil {
|
|
||||||
return ErrInvalidCIDR.Subject(cidrStr)
|
|
||||||
}
|
|
||||||
*cidr = CIDR(*ipnet)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cidr *CIDR) Contains(ip net.IP) bool {
|
|
||||||
return (*net.IPNet)(cidr).Contains(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cidr *CIDR) String() string {
|
|
||||||
return (*net.IPNet)(cidr).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cidr *CIDR) Equals(other *CIDR) bool {
|
|
||||||
return (*net.IPNet)(cidr).IP.Equal(other.IP) && bytes.Equal(cidr.Mask, other.Mask)
|
|
||||||
}
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -32,6 +33,7 @@ var (
|
||||||
ErrMapMissingColon = E.New("map missing colon")
|
ErrMapMissingColon = E.New("map missing colon")
|
||||||
ErrMapTooManyColons = E.New("map too many colons")
|
ErrMapTooManyColons = E.New("map too many colons")
|
||||||
ErrUnknownField = E.New("unknown field")
|
ErrUnknownField = E.New("unknown field")
|
||||||
|
ErrValidationError = E.New("validation error")
|
||||||
)
|
)
|
||||||
|
|
||||||
var validate = validator.New()
|
var validate = validator.New()
|
||||||
|
@ -203,11 +205,23 @@ func Deserialize(src SerializedObject, dst any) E.Error {
|
||||||
errs.Add(err.Subject(k))
|
errs.Add(err.Subject(k))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dstV.Interface()))))
|
errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if needValidate {
|
if needValidate {
|
||||||
errs.Add(validate.Struct(dstV.Interface()))
|
err := validate.Struct(dstV.Interface())
|
||||||
|
var valErrs validator.ValidationErrors
|
||||||
|
if errors.As(err, &valErrs) {
|
||||||
|
for _, e := range valErrs {
|
||||||
|
detail := e.ActualTag()
|
||||||
|
if e.Param() != "" {
|
||||||
|
detail += ":" + e.Param()
|
||||||
|
}
|
||||||
|
errs.Add(ErrValidationError.
|
||||||
|
Subject(strutils.ToLowerNoSnake(e.Field())).
|
||||||
|
Withf("require %q", detail))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
|
@ -337,6 +351,16 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E
|
||||||
}
|
}
|
||||||
dst.Set(reflect.ValueOf(d))
|
dst.Set(reflect.ValueOf(d))
|
||||||
return
|
return
|
||||||
|
case reflect.TypeFor[net.IPNet]():
|
||||||
|
if !strings.Contains(src, "/") {
|
||||||
|
src += "/32" // single IP
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(src)
|
||||||
|
if err != nil {
|
||||||
|
return true, E.From(err)
|
||||||
|
}
|
||||||
|
dst.Set(reflect.ValueOf(*ipnet))
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
// primitive types / simple types
|
// primitive types / simple types
|
||||||
|
|
Loading…
Add table
Reference in a new issue