feat(middlewares): middleware bypass rules

This commit is contained in:
yusing 2025-05-05 18:01:07 +08:00
parent 75db09b1f3
commit ad60f377ba
4 changed files with 221 additions and 19 deletions

View file

@ -0,0 +1,53 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/route/rules"
)
type Bypass []rules.RuleOn
func (b Bypass) ShouldBypass(r *http.Request) bool {
cached := rules.NewCache()
defer cached.Release()
for _, rule := range b {
if rule.Check(cached, r) {
return true
}
}
return false
}
type checkBypass struct {
bypass Bypass
modReq RequestModifier
modRes ResponseModifier
}
func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
if c.modReq == nil || c.bypass.ShouldBypass(r) {
return true
}
return c.modReq.before(w, r)
}
func (c *checkBypass) modifyResponse(resp *http.Response) error {
if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) {
return nil
}
return c.modRes.modifyResponse(resp)
}
func (m *Middleware) withCheckBypass() any {
if len(m.Bypass) > 0 {
modReq, _ := m.impl.(RequestModifier)
modRes, _ := m.impl.(ResponseModifier)
return &checkBypass{
bypass: m.Bypass,
modReq: modReq,
modRes: modRes,
}
}
return m.impl
}

View file

@ -0,0 +1,131 @@
package middleware_test
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func noOpHandler(w http.ResponseWriter, r *http.Request) {}
func TestBypassCIDR(t *testing.T) {
mr, err := ModifyRequest.New(map[string]any{
"set_headers": map[string]string{
"Test-Header": "test-value",
},
"bypass": []string{"remote 127.0.0.1/32"},
})
expect.NoError(t, err)
tests := []struct {
name string
remoteAddr string
expectBypass bool
}{
{"bypass", "127.0.0.1:8080", true},
{"no_bypass", "192.168.1.1:8080", false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
req.RemoteAddr = test.remoteAddr
recorder := httptest.NewRecorder()
mr.ModifyRequest(noOpHandler, recorder, req)
expect.NoError(t, err)
if test.expectBypass {
expect.Equal(t, req.Header.Get("Test-Header"), "")
} else {
expect.Equal(t, req.Header.Get("Test-Header"), "test-value")
}
})
}
}
func TestBypassPath(t *testing.T) {
mr, err := ModifyRequest.New(map[string]any{
"bypass": []string{"path /test/*", "path /api"},
"set_headers": map[string]string{
"Test-Header": "test-value",
},
})
expect.NoError(t, err)
tests := []struct {
name string
path string
expectBypass bool
}{
{"bypass", "/test/123", true},
{"bypass2", "/test/123/456", true},
{"bypass3", "/api", true},
{"no_bypass", "/test1/123/456", false},
{"no_bypass2", "/api/123", false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
recorder := httptest.NewRecorder()
mr.ModifyRequest(noOpHandler, recorder, req)
expect.NoError(t, err)
if test.expectBypass {
expect.Equal(t, req.Header.Get("Test-Header"), "")
} else {
expect.Equal(t, req.Header.Get("Test-Header"), "test-value")
}
})
}
}
type fakeRoundTripper struct{}
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Request: req,
Header: make(http.Header),
}, nil
}
func TestReverseProxyBypass(t *testing.T) {
rp := reverseproxy.NewReverseProxy("test", types.MustParseURL("http://example.com"), fakeRoundTripper{})
err := PatchReverseProxy(rp, map[string]OptionsRaw{
"response": {
"bypass": "path /test/* | path /api",
"set_headers": map[string]string{
"Test-Header": "test-value",
},
},
})
expect.NoError(t, err)
tests := []struct {
name string
path string
expectBypass bool
}{
{"bypass", "/test/123", true},
{"bypass2", "/test/123/456", true},
{"bypass3", "/api", true},
{"no_bypass", "/test1/123/456", false},
{"no_bypass2", "/api/123", false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
recorder := httptest.NewRecorder()
rp.ServeHTTP(recorder, req)
if test.expectBypass {
expect.Equal(t, recorder.Header().Get("Test-Header"), "")
} else {
expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value")
}
})
}
}

View file

@ -2,6 +2,7 @@ package middleware
import (
"encoding/json"
"maps"
"net/http"
"reflect"
"sort"
@ -23,16 +24,22 @@ type (
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
name string
construct ImplNewFunc
impl any
commonOptions = struct {
// priority is only applied for ReverseProxy.
//
// Middleware compose follows the order of the slice
//
// Default is 10, 0 is the highest
priority int
Priority int `json:"priority"`
Bypass Bypass `json:"bypass"`
}
Middleware struct {
name string
construct ImplNewFunc
impl any
commonOptions
}
ByPriority []*Middleware
@ -55,7 +62,7 @@ const DefaultPriority = 10
func (m ByPriority) Len() int { return len(m) }
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
func (m ByPriority) Less(i, j int) bool { return m[i].Priority < m[j].Priority }
func NewMiddleware[ImplType any]() *Middleware {
// type check
@ -107,21 +114,22 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
if len(optsRaw) == 0 {
return nil
}
priority, ok := optsRaw["priority"].(int)
if ok {
m.priority = priority
// remove priority for deserialization, restore later
delete(optsRaw, "priority")
defer func() {
optsRaw["priority"] = priority
}()
} else {
m.priority = DefaultPriority
commonOpts := map[string]any{
"priority": optsRaw["priority"],
"bypass": optsRaw["bypass"],
}
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
return err
}
optsRaw = maps.Clone(optsRaw)
for k := range commonOpts {
delete(optsRaw, k)
}
return utils.MapUnmarshalValidate(optsRaw, m.impl)
}
func (m *Middleware) finalize() error {
m.impl = m.withCheckBypass()
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
return nil
@ -159,10 +167,16 @@ func (m *Middleware) String() string {
}
func (m *Middleware) MarshalJSON() ([]byte, error) {
type allOptions struct {
commonOptions
any
}
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
"priority": m.priority,
"name": m.name,
"options": allOptions{
commonOptions: m.commonOptions,
any: m.impl,
},
}, "", " ")
}

View file

@ -16,6 +16,10 @@ type RuleOn struct {
checker Checker
}
func (on *RuleOn) Check(cached Cache, r *http.Request) bool {
return on.checker.Check(cached, r)
}
const (
OnHeader = "header"
OnQuery = "query"