mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-22 12:24:02 +02:00
removed unnecessary pointer indirection, added rate limiter middleware
This commit is contained in:
parent
3bf520541b
commit
8df28628ec
17 changed files with 167 additions and 79 deletions
16
Makefile
16
Makefile
|
@ -28,23 +28,21 @@ get:
|
||||||
go get -u ./cmd && go mod tidy
|
go get -u ./cmd && go mod tidy
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
make build
|
GOPROXY_DEBUG=1 make run
|
||||||
sudo GOPROXY_DEBUG=1 bin/go-proxy
|
|
||||||
|
|
||||||
debug-trace:
|
debug-trace:
|
||||||
make build
|
GOPROXY_DEBUG=1 GOPROXY_TRACE=1 run
|
||||||
sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
|
|
||||||
|
|
||||||
profile:
|
profile:
|
||||||
GODEBUG=gctrace=1 make build
|
GODEBUG=gctrace=1 make debug
|
||||||
sudo GOPROXY_DEBUG=1 bin/go-proxy
|
|
||||||
|
run: build
|
||||||
|
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/go-proxy
|
||||||
|
bin/go-proxy
|
||||||
|
|
||||||
mtrace:
|
mtrace:
|
||||||
bin/go-proxy debug-ls-mtrace > mtrace.json
|
bin/go-proxy debug-ls-mtrace > mtrace.json
|
||||||
|
|
||||||
run:
|
|
||||||
make build && sudo bin/go-proxy
|
|
||||||
|
|
||||||
archive:
|
archive:
|
||||||
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
|
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip
|
||||||
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -15,6 +15,7 @@ require (
|
||||||
github.com/santhosh-tekuri/jsonschema v1.2.4
|
github.com/santhosh-tekuri/jsonschema v1.2.4
|
||||||
golang.org/x/net v0.30.0
|
golang.org/x/net v0.30.0
|
||||||
golang.org/x/text v0.19.0
|
golang.org/x/text v0.19.0
|
||||||
|
golang.org/x/time v0.7.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,7 +58,6 @@ require (
|
||||||
golang.org/x/oauth2 v0.23.0 // indirect
|
golang.org/x/oauth2 v0.23.0 // indirect
|
||||||
golang.org/x/sync v0.8.0 // indirect
|
golang.org/x/sync v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.26.0 // indirect
|
golang.org/x/sys v0.26.0 // indirect
|
||||||
golang.org/x/time v0.7.0 // indirect
|
|
||||||
golang.org/x/tools v0.26.0 // indirect
|
golang.org/x/tools v0.26.0 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
gotest.tools/v3 v3.5.1 // indirect
|
gotest.tools/v3 v3.5.1 // indirect
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type cidrWhitelist struct {
|
type cidrWhitelist struct {
|
||||||
*cidrWhitelistOpts
|
cidrWhitelistOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,18 +22,15 @@ type cidrWhitelistOpts struct {
|
||||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||||
}
|
}
|
||||||
|
|
||||||
var CIDRWhiteList = &cidrWhitelist{
|
var (
|
||||||
m: &Middleware{withOptions: NewCIDRWhitelist},
|
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
|
||||||
}
|
cidrWhitelistDefaults = cidrWhitelistOpts{
|
||||||
|
|
||||||
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
|
||||||
return &cidrWhitelistOpts{
|
|
||||||
Allow: []*types.CIDR{},
|
Allow: []*types.CIDR{},
|
||||||
StatusCode: http.StatusForbidden,
|
StatusCode: http.StatusForbidden,
|
||||||
Message: "IP not allowed",
|
Message: "IP not allowed",
|
||||||
cachedAddr: F.NewMapOf[string, bool](),
|
cachedAddr: F.NewMapOf[string, bool](),
|
||||||
}
|
}
|
||||||
}
|
)
|
||||||
|
|
||||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||||
wl := new(cidrWhitelist)
|
wl := new(cidrWhitelist)
|
||||||
|
@ -41,8 +38,8 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||||
impl: wl,
|
impl: wl,
|
||||||
before: wl.checkIP,
|
before: wl.checkIP,
|
||||||
}
|
}
|
||||||
wl.cidrWhitelistOpts = cidrWhitelistDefaults()
|
wl.cidrWhitelistOpts = cidrWhitelistDefaults
|
||||||
err := Deserialize(opts, wl.cidrWhitelistOpts)
|
err := Deserialize(opts, &wl.cidrWhitelistOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,8 +27,8 @@ func TestCIDRWhitelist(t *testing.T) {
|
||||||
for range 10 {
|
for range 10 {
|
||||||
result, err := newMiddlewareTest(deny, nil)
|
result, err := newMiddlewareTest(deny, nil)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
|
||||||
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults.Message)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,7 @@ var (
|
||||||
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
|
||||||
)
|
)
|
||||||
|
|
||||||
var CloudflareRealIP = &realIP{
|
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
|
||||||
m: &Middleware{withOptions: NewCloudflareRealIP},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
||||||
cri := new(realIP)
|
cri := new(realIP)
|
||||||
|
@ -46,7 +44,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cri.realIPOpts = &realIPOpts{
|
cri.realIPOpts = realIPOpts{
|
||||||
Header: "CF-Connecting-IP",
|
Header: "CF-Connecting-IP",
|
||||||
Recursive: true,
|
Recursive: true,
|
||||||
}
|
}
|
||||||
|
|
5
internal/net/http/middleware/errors.go
Normal file
5
internal/net/http/middleware/errors.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
|
||||||
|
var ErrZeroValue = E.New("cannot be zero")
|
|
@ -19,7 +19,7 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
forwardAuth struct {
|
forwardAuth struct {
|
||||||
*forwardAuthOpts
|
forwardAuthOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
client http.Client
|
client http.Client
|
||||||
}
|
}
|
||||||
|
@ -33,14 +33,11 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var ForwardAuth = &forwardAuth{
|
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
|
||||||
m: &Middleware{withOptions: NewForwardAuthfunc},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
fa := new(forwardAuth)
|
fa := new(forwardAuth)
|
||||||
fa.forwardAuthOpts = new(forwardAuthOpts)
|
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
||||||
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err := url.Parse(fa.Address); err != nil {
|
if _, err := url.Parse(fa.Address); err != nil {
|
||||||
|
|
|
@ -28,7 +28,6 @@ type (
|
||||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||||
|
|
||||||
OptionsRaw = map[string]any
|
OptionsRaw = map[string]any
|
||||||
Options any
|
|
||||||
|
|
||||||
Middleware struct {
|
Middleware struct {
|
||||||
_ U.NoCopy
|
_ U.NoCopy
|
||||||
|
|
|
@ -35,20 +35,23 @@ func All() map[string]*Middleware {
|
||||||
|
|
||||||
// initialize middleware names and label parsers.
|
// initialize middleware names and label parsers.
|
||||||
func init() {
|
func init() {
|
||||||
|
// snakes and cases will be stripped on `Get`
|
||||||
|
// so keys are lowercase without snake.
|
||||||
allMiddlewares = map[string]*Middleware{
|
allMiddlewares = map[string]*Middleware{
|
||||||
"setxforwarded": SetXForwarded,
|
"setxforwarded": SetXForwarded,
|
||||||
"hidexforwarded": HideXForwarded,
|
"hidexforwarded": HideXForwarded,
|
||||||
"redirecthttp": RedirectHTTP,
|
"redirecthttp": RedirectHTTP,
|
||||||
"modifyresponse": ModifyResponse.m,
|
"modifyresponse": ModifyResponse,
|
||||||
"modifyrequest": ModifyRequest.m,
|
"modifyrequest": ModifyRequest,
|
||||||
"errorpage": CustomErrorPage,
|
"errorpage": CustomErrorPage,
|
||||||
"customerrorpage": CustomErrorPage,
|
"customerrorpage": CustomErrorPage,
|
||||||
"realip": RealIP.m,
|
"realip": RealIP,
|
||||||
"cloudflarerealip": CloudflareRealIP.m,
|
"cloudflarerealip": CloudflareRealIP,
|
||||||
"cidrwhitelist": CIDRWhiteList.m,
|
"cidrwhitelist": CIDRWhiteList,
|
||||||
|
"ratelimit": RateLimiter,
|
||||||
|
|
||||||
// !experimental
|
// !experimental
|
||||||
"forwardauth": ForwardAuth.m,
|
"forwardauth": ForwardAuth,
|
||||||
// "oauth2": OAuth2.m,
|
// "oauth2": OAuth2.m,
|
||||||
}
|
}
|
||||||
names := make(map[*Middleware][]string)
|
names := make(map[*Middleware][]string)
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
modifyRequest struct {
|
modifyRequest struct {
|
||||||
*modifyRequestOpts
|
modifyRequestOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
}
|
}
|
||||||
// order: set_headers -> add_headers -> hide_headers
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
|
@ -18,9 +18,7 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var ModifyRequest = &modifyRequest{
|
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
|
||||||
m: &Middleware{withOptions: NewModifyRequest},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
mr := new(modifyRequest)
|
mr := new(modifyRequest)
|
||||||
|
@ -34,8 +32,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
impl: mr,
|
impl: mr,
|
||||||
before: Rewrite(mrFunc),
|
before: Rewrite(mrFunc),
|
||||||
}
|
}
|
||||||
mr.modifyRequestOpts = new(modifyRequestOpts)
|
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
||||||
err := Deserialize(optsRaw, mr.modifyRequestOpts)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("set_options", func(t *testing.T) {
|
t.Run("set_options", func(t *testing.T) {
|
||||||
mr, err := ModifyRequest.m.WithOptionsClone(opts)
|
mr, err := ModifyRequest.WithOptionsClone(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
@ -23,7 +23,7 @@ func TestSetModifyRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("request_headers", func(t *testing.T) {
|
t.Run("request_headers", func(t *testing.T) {
|
||||||
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
|
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||||
middlewareOpt: opts,
|
middlewareOpt: opts,
|
||||||
})
|
})
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
|
|
@ -9,16 +9,14 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
modifyResponse struct {
|
modifyResponse struct {
|
||||||
*modifyResponseOpts
|
modifyResponseOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
}
|
}
|
||||||
// order: set_headers -> add_headers -> hide_headers
|
// order: set_headers -> add_headers -> hide_headers
|
||||||
modifyResponseOpts = modifyRequestOpts
|
modifyResponseOpts = modifyRequestOpts
|
||||||
)
|
)
|
||||||
|
|
||||||
var ModifyResponse = &modifyResponse{
|
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
|
||||||
m: &Middleware{withOptions: NewModifyResponse},
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
mr := new(modifyResponse)
|
mr := new(modifyResponse)
|
||||||
|
@ -28,8 +26,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
} else {
|
} else {
|
||||||
mr.m.modifyResponse = mr.modifyResponse
|
mr.m.modifyResponse = mr.modifyResponse
|
||||||
}
|
}
|
||||||
mr.modifyResponseOpts = new(modifyResponseOpts)
|
err := Deserialize(optsRaw, &mr.modifyResponseOpts)
|
||||||
err := Deserialize(optsRaw, mr.modifyResponseOpts)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("set_options", func(t *testing.T) {
|
t.Run("set_options", func(t *testing.T) {
|
||||||
mr, err := ModifyResponse.m.WithOptionsClone(opts)
|
mr, err := ModifyResponse.WithOptionsClone(opts)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||||
|
@ -23,7 +23,7 @@ func TestSetModifyResponse(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("request_headers", func(t *testing.T) {
|
t.Run("request_headers", func(t *testing.T) {
|
||||||
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
|
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||||
middlewareOpt: opts,
|
middlewareOpt: opts,
|
||||||
})
|
})
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
|
|
81
internal/net/http/middleware/rate_limit.go
Normal file
81
internal/net/http/middleware/rate_limit.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
requestMap = map[string]*rate.Limiter
|
||||||
|
rateLimiter struct {
|
||||||
|
requestMap requestMap
|
||||||
|
newLimiter func() *rate.Limiter
|
||||||
|
m *Middleware
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimiterOpts struct {
|
||||||
|
Average int `json:"average"`
|
||||||
|
Burst int `json:"burst"`
|
||||||
|
Period time.Duration `json:"period"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
RateLimiter = &Middleware{withOptions: NewRateLimiter}
|
||||||
|
rateLimiterOptsDefault = rateLimiterOpts{
|
||||||
|
Average: 100,
|
||||||
|
Burst: 1,
|
||||||
|
Period: time.Second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
|
rl := new(rateLimiter)
|
||||||
|
opts := rateLimiterOptsDefault
|
||||||
|
err := Deserialize(optsRaw, &opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case opts.Average == 0:
|
||||||
|
return nil, ErrZeroValue.Subject("average")
|
||||||
|
case opts.Burst == 0:
|
||||||
|
return nil, ErrZeroValue.Subject("burst")
|
||||||
|
case opts.Period == 0:
|
||||||
|
return nil, ErrZeroValue.Subject("period")
|
||||||
|
}
|
||||||
|
rl.requestMap = make(requestMap, 0)
|
||||||
|
rl.newLimiter = func() *rate.Limiter {
|
||||||
|
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
|
||||||
|
}
|
||||||
|
rl.m = &Middleware{
|
||||||
|
impl: rl,
|
||||||
|
before: rl.limit,
|
||||||
|
}
|
||||||
|
return rl.m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
|
||||||
|
limiter, ok := rl.requestMap[r.RemoteAddr]
|
||||||
|
if !ok {
|
||||||
|
limiter = rl.newLimiter()
|
||||||
|
rl.requestMap[r.RemoteAddr] = limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
|
if limiter.Allow() {
|
||||||
|
next(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||||
|
}
|
27
internal/net/http/middleware/rate_limit_test.go
Normal file
27
internal/net/http/middleware/rate_limit_test.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimit(t *testing.T) {
|
||||||
|
opts := OptionsRaw{
|
||||||
|
"average": "10",
|
||||||
|
"burst": "10",
|
||||||
|
"period": "1s",
|
||||||
|
}
|
||||||
|
|
||||||
|
rl, err := NewRateLimiter(opts)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
for range 10 {
|
||||||
|
result, err := newMiddlewareTest(rl, nil)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||||
|
}
|
||||||
|
result, err := newMiddlewareTest(rl, nil)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
|
||||||
|
}
|
|
@ -1,12 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
type (
|
|
||||||
rateLimiter struct {
|
|
||||||
*rateLimiterOpts
|
|
||||||
m *Middleware
|
|
||||||
}
|
|
||||||
|
|
||||||
rateLimiterOpts struct {
|
|
||||||
Count int `json:"count"`
|
|
||||||
}
|
|
||||||
)
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||||
|
|
||||||
type realIP struct {
|
type realIP struct {
|
||||||
*realIPOpts
|
realIPOpts
|
||||||
m *Middleware
|
m *Middleware
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,16 +30,13 @@ type realIPOpts struct {
|
||||||
Recursive bool `json:"recursive"`
|
Recursive bool `json:"recursive"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var RealIP = &realIP{
|
var (
|
||||||
m: &Middleware{withOptions: NewRealIP},
|
RealIP = &Middleware{withOptions: NewRealIP}
|
||||||
}
|
realIPOptsDefault = realIPOpts{
|
||||||
|
|
||||||
var realIPOptsDefault = func() *realIPOpts {
|
|
||||||
return &realIPOpts{
|
|
||||||
Header: "X-Real-IP",
|
Header: "X-Real-IP",
|
||||||
From: []*types.CIDR{},
|
From: []*types.CIDR{},
|
||||||
}
|
}
|
||||||
}
|
)
|
||||||
|
|
||||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
||||||
riWithOpts := new(realIP)
|
riWithOpts := new(realIP)
|
||||||
|
@ -47,11 +44,14 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
||||||
impl: riWithOpts,
|
impl: riWithOpts,
|
||||||
before: Rewrite(riWithOpts.setRealIP),
|
before: Rewrite(riWithOpts.setRealIP),
|
||||||
}
|
}
|
||||||
riWithOpts.realIPOpts = realIPOptsDefault()
|
riWithOpts.realIPOpts = realIPOptsDefault
|
||||||
err := Deserialize(opts, riWithOpts.realIPOpts)
|
err := Deserialize(opts, &riWithOpts.realIPOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if len(riWithOpts.From) == 0 {
|
||||||
|
return nil, E.New("no allowed CIDRs").Subject("from")
|
||||||
|
}
|
||||||
return riWithOpts.m, nil
|
return riWithOpts.m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,9 +70,10 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
clientIPStr = req.RemoteAddr
|
clientIPStr = req.RemoteAddr
|
||||||
}
|
}
|
||||||
clientIP := net.ParseIP(clientIPStr)
|
|
||||||
|
|
||||||
var isTrusted = false
|
clientIP := net.ParseIP(clientIPStr)
|
||||||
|
isTrusted := false
|
||||||
|
|
||||||
for _, CIDR := range ri.From {
|
for _, CIDR := range ri.From {
|
||||||
if CIDR.Contains(clientIP) {
|
if CIDR.Contains(clientIP) {
|
||||||
isTrusted = true
|
isTrusted = true
|
||||||
|
|
Loading…
Add table
Reference in a new issue