removed unnecessary pointer indirection, added rate limiter middleware

This commit is contained in:
yusing 2024-11-03 07:07:30 +08:00
parent 3bf520541b
commit 8df28628ec
17 changed files with 167 additions and 79 deletions

View file

@ -28,23 +28,21 @@ get:
go get -u ./cmd && go mod tidy
debug:
make build
sudo GOPROXY_DEBUG=1 bin/go-proxy
GOPROXY_DEBUG=1 make run
debug-trace:
make build
sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
GOPROXY_DEBUG=1 GOPROXY_TRACE=1 run
profile:
GODEBUG=gctrace=1 make build
sudo GOPROXY_DEBUG=1 bin/go-proxy
GODEBUG=gctrace=1 make debug
run: build
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/go-proxy
bin/go-proxy
mtrace:
bin/go-proxy debug-ls-mtrace > mtrace.json
run:
make build && sudo bin/go-proxy
archive:
git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip

2
go.mod
View file

@ -15,6 +15,7 @@ require (
github.com/santhosh-tekuri/jsonschema v1.2.4
golang.org/x/net v0.30.0
golang.org/x/text v0.19.0
golang.org/x/time v0.7.0
gopkg.in/yaml.v3 v3.0.1
)
@ -57,7 +58,6 @@ require (
golang.org/x/oauth2 v0.23.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/time v0.7.0 // indirect
golang.org/x/tools v0.26.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gotest.tools/v3 v3.5.1 // indirect

View file

@ -10,7 +10,7 @@ import (
)
type cidrWhitelist struct {
*cidrWhitelistOpts
cidrWhitelistOpts
m *Middleware
}
@ -22,18 +22,15 @@ type cidrWhitelistOpts struct {
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
var CIDRWhiteList = &cidrWhitelist{
m: &Middleware{withOptions: NewCIDRWhitelist},
}
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
return &cidrWhitelistOpts{
var (
CIDRWhiteList = &Middleware{withOptions: NewCIDRWhitelist}
cidrWhitelistDefaults = cidrWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
cachedAddr: F.NewMapOf[string, bool](),
}
}
)
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist)
@ -41,8 +38,8 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
impl: wl,
before: wl.checkIP,
}
wl.cidrWhitelistOpts = cidrWhitelistDefaults()
err := Deserialize(opts, wl.cidrWhitelistOpts)
wl.cidrWhitelistOpts = cidrWhitelistDefaults
err := Deserialize(opts, &wl.cidrWhitelistOpts)
if err != nil {
return nil, err
}

View file

@ -27,8 +27,8 @@ func TestCIDRWhitelist(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults.StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults.Message)
}
})

View file

@ -29,9 +29,7 @@ var (
cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
)
var CloudflareRealIP = &realIP{
m: &Middleware{withOptions: NewCloudflareRealIP},
}
var CloudflareRealIP = &Middleware{withOptions: NewCloudflareRealIP}
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP)
@ -46,7 +44,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
next(w, r)
},
}
cri.realIPOpts = &realIPOpts{
cri.realIPOpts = realIPOpts{
Header: "CF-Connecting-IP",
Recursive: true,
}

View file

@ -0,0 +1,5 @@
package middleware
import E "github.com/yusing/go-proxy/internal/error"
var ErrZeroValue = E.New("cannot be zero")

View file

@ -19,7 +19,7 @@ import (
type (
forwardAuth struct {
*forwardAuthOpts
forwardAuthOpts
m *Middleware
client http.Client
}
@ -33,14 +33,11 @@ type (
}
)
var ForwardAuth = &forwardAuth{
m: &Middleware{withOptions: NewForwardAuthfunc},
}
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts)
if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
if _, err := url.Parse(fa.Address); err != nil {

View file

@ -28,7 +28,6 @@ type (
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
Options any
Middleware struct {
_ U.NoCopy

View file

@ -35,20 +35,23 @@ func All() map[string]*Middleware {
// initialize middleware names and label parsers.
func init() {
// snakes and cases will be stripped on `Get`
// so keys are lowercase without snake.
allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"modifyresponse": ModifyResponse,
"modifyrequest": ModifyRequest,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
"realip": RealIP,
"cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental
"forwardauth": ForwardAuth.m,
"forwardauth": ForwardAuth,
// "oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)

View file

@ -7,7 +7,7 @@ import (
type (
modifyRequest struct {
*modifyRequestOpts
modifyRequestOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
@ -18,9 +18,7 @@ type (
}
)
var ModifyRequest = &modifyRequest{
m: &Middleware{withOptions: NewModifyRequest},
}
var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
@ -34,8 +32,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
impl: mr,
before: Rewrite(mrFunc),
}
mr.modifyRequestOpts = new(modifyRequestOpts)
err := Deserialize(optsRaw, mr.modifyRequestOpts)
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil {
return nil, err
}

View file

@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts)
mr, err := ModifyRequest.WithOptionsClone(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
@ -23,7 +23,7 @@ func TestSetModifyRequest(t *testing.T) {
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err)

View file

@ -9,16 +9,14 @@ import (
type (
modifyResponse struct {
*modifyResponseOpts
modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts = modifyRequestOpts
)
var ModifyResponse = &modifyResponse{
m: &Middleware{withOptions: NewModifyResponse},
}
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
@ -28,8 +26,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
} else {
mr.m.modifyResponse = mr.modifyResponse
}
mr.modifyResponseOpts = new(modifyResponseOpts)
err := Deserialize(optsRaw, mr.modifyResponseOpts)
err := Deserialize(optsRaw, &mr.modifyResponseOpts)
if err != nil {
return nil, err
}

View file

@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts)
mr, err := ModifyResponse.WithOptionsClone(opts)
ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
@ -23,7 +23,7 @@ func TestSetModifyResponse(t *testing.T) {
})
t.Run("request_headers", func(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
})
ExpectNoError(t, err)

View file

@ -0,0 +1,81 @@
package middleware
import (
"net/http"
"sync"
"time"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
requestMap requestMap
newLimiter func() *rate.Limiter
m *Middleware
mu sync.Mutex
}
rateLimiterOpts struct {
Average int `json:"average"`
Burst int `json:"burst"`
Period time.Duration `json:"period"`
}
)
var (
RateLimiter = &Middleware{withOptions: NewRateLimiter}
rateLimiterOptsDefault = rateLimiterOpts{
Average: 100,
Burst: 1,
Period: time.Second,
}
)
func NewRateLimiter(optsRaw OptionsRaw) (*Middleware, E.Error) {
rl := new(rateLimiter)
opts := rateLimiterOptsDefault
err := Deserialize(optsRaw, &opts)
if err != nil {
return nil, err
}
switch {
case opts.Average == 0:
return nil, ErrZeroValue.Subject("average")
case opts.Burst == 0:
return nil, ErrZeroValue.Subject("burst")
case opts.Period == 0:
return nil, ErrZeroValue.Subject("period")
}
rl.requestMap = make(requestMap, 0)
rl.newLimiter = func() *rate.Limiter {
return rate.NewLimiter(rate.Limit(opts.Average)*rate.Every(opts.Period), opts.Burst)
}
rl.m = &Middleware{
impl: rl,
before: rl.limit,
}
return rl.m, nil
}
func (rl *rateLimiter) limit(next http.HandlerFunc, w ResponseWriter, r *Request) {
rl.mu.Lock()
limiter, ok := rl.requestMap[r.RemoteAddr]
if !ok {
limiter = rl.newLimiter()
rl.requestMap[r.RemoteAddr] = limiter
}
rl.mu.Unlock()
if limiter.Allow() {
next(w, r)
return
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
}

View file

@ -0,0 +1,27 @@
package middleware
import (
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRateLimit(t *testing.T) {
opts := OptionsRaw{
"average": "10",
"burst": "10",
"period": "1s",
}
rl, err := NewRateLimiter(opts)
ExpectNoError(t, err)
for range 10 {
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
result, err := newMiddlewareTest(rl, nil)
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTooManyRequests)
}

View file

@ -1,12 +0,0 @@
package middleware
type (
rateLimiter struct {
*rateLimiterOpts
m *Middleware
}
rateLimiterOpts struct {
Count int `json:"count"`
}
)

View file

@ -10,7 +10,7 @@ import (
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
type realIP struct {
*realIPOpts
realIPOpts
m *Middleware
}
@ -30,16 +30,13 @@ type realIPOpts struct {
Recursive bool `json:"recursive"`
}
var RealIP = &realIP{
m: &Middleware{withOptions: NewRealIP},
}
var realIPOptsDefault = func() *realIPOpts {
return &realIPOpts{
var (
RealIP = &Middleware{withOptions: NewRealIP}
realIPOptsDefault = realIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{},
}
}
)
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
riWithOpts := new(realIP)
@ -47,11 +44,14 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
impl: riWithOpts,
before: Rewrite(riWithOpts.setRealIP),
}
riWithOpts.realIPOpts = realIPOptsDefault()
err := Deserialize(opts, riWithOpts.realIPOpts)
riWithOpts.realIPOpts = realIPOptsDefault
err := Deserialize(opts, &riWithOpts.realIPOpts)
if err != nil {
return nil, err
}
if len(riWithOpts.From) == 0 {
return nil, E.New("no allowed CIDRs").Subject("from")
}
return riWithOpts.m, nil
}
@ -70,9 +70,10 @@ func (ri *realIP) setRealIP(req *Request) {
if err != nil {
clientIPStr = req.RemoteAddr
}
clientIP := net.ParseIP(clientIPStr)
var isTrusted = false
clientIP := net.ParseIP(clientIPStr)
isTrusted := false
for _, CIDR := range ri.From {
if CIDR.Contains(clientIP) {
isTrusted = true