diff --git a/internal/net/gphttp/middleware/bypass.go b/internal/net/gphttp/middleware/bypass.go new file mode 100644 index 0000000..a5f1b0a --- /dev/null +++ b/internal/net/gphttp/middleware/bypass.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "net/http" + + "github.com/yusing/go-proxy/internal/route/rules" +) + +type Bypass []rules.RuleOn + +func (b Bypass) ShouldBypass(r *http.Request) bool { + cached := rules.NewCache() + defer cached.Release() + for _, rule := range b { + if rule.Check(cached, r) { + return true + } + } + return false +} + +type checkBypass struct { + bypass Bypass + modReq RequestModifier + modRes ResponseModifier +} + +func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) { + if c.modReq == nil || c.bypass.ShouldBypass(r) { + return true + } + return c.modReq.before(w, r) +} + +func (c *checkBypass) modifyResponse(resp *http.Response) error { + if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) { + return nil + } + return c.modRes.modifyResponse(resp) +} + +func (m *Middleware) withCheckBypass() any { + if len(m.Bypass) > 0 { + modReq, _ := m.impl.(RequestModifier) + modRes, _ := m.impl.(ResponseModifier) + return &checkBypass{ + bypass: m.Bypass, + modReq: modReq, + modRes: modRes, + } + } + return m.impl +} diff --git a/internal/net/gphttp/middleware/bypass_test.go b/internal/net/gphttp/middleware/bypass_test.go new file mode 100644 index 0000000..728b866 --- /dev/null +++ b/internal/net/gphttp/middleware/bypass_test.go @@ -0,0 +1,131 @@ +package middleware_test + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + . "github.com/yusing/go-proxy/internal/net/gphttp/middleware" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" + "github.com/yusing/go-proxy/internal/net/types" + expect "github.com/yusing/go-proxy/internal/utils/testing" +) + +func noOpHandler(w http.ResponseWriter, r *http.Request) {} + +func TestBypassCIDR(t *testing.T) { + mr, err := ModifyRequest.New(map[string]any{ + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + "bypass": []string{"remote 127.0.0.1/32"}, + }) + expect.NoError(t, err) + + tests := []struct { + name string + remoteAddr string + expectBypass bool + }{ + {"bypass", "127.0.0.1:8080", true}, + {"no_bypass", "192.168.1.1:8080", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com", nil) + req.RemoteAddr = test.remoteAddr + recorder := httptest.NewRecorder() + mr.ModifyRequest(noOpHandler, recorder, req) + expect.NoError(t, err) + if test.expectBypass { + expect.Equal(t, req.Header.Get("Test-Header"), "") + } else { + expect.Equal(t, req.Header.Get("Test-Header"), "test-value") + } + }) + } +} + +func TestBypassPath(t *testing.T) { + mr, err := ModifyRequest.New(map[string]any{ + "bypass": []string{"path /test/*", "path /api"}, + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + }) + expect.NoError(t, err) + + tests := []struct { + name string + path string + expectBypass bool + }{ + {"bypass", "/test/123", true}, + {"bypass2", "/test/123/456", true}, + {"bypass3", "/api", true}, + {"no_bypass", "/test1/123/456", false}, + {"no_bypass2", "/api/123", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + recorder := httptest.NewRecorder() + mr.ModifyRequest(noOpHandler, recorder, req) + expect.NoError(t, err) + if test.expectBypass { + expect.Equal(t, req.Header.Get("Test-Header"), "") + } else { + expect.Equal(t, req.Header.Get("Test-Header"), "test-value") + } + }) + } +} + +type fakeRoundTripper struct{} + +func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("")), + Request: req, + Header: make(http.Header), + }, nil +} + +func TestReverseProxyBypass(t *testing.T) { + rp := reverseproxy.NewReverseProxy("test", types.MustParseURL("http://example.com"), fakeRoundTripper{}) + err := PatchReverseProxy(rp, map[string]OptionsRaw{ + "response": { + "bypass": "path /test/* | path /api", + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + }, + }) + expect.NoError(t, err) + tests := []struct { + name string + path string + expectBypass bool + }{ + {"bypass", "/test/123", true}, + {"bypass2", "/test/123/456", true}, + {"bypass3", "/api", true}, + {"no_bypass", "/test1/123/456", false}, + {"no_bypass2", "/api/123", false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + recorder := httptest.NewRecorder() + rp.ServeHTTP(recorder, req) + if test.expectBypass { + expect.Equal(t, recorder.Header().Get("Test-Header"), "") + } else { + expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value") + } + }) + } +} diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index c6b5fd8..a39ffda 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/json" + "maps" "net/http" "reflect" "sort" @@ -23,16 +24,22 @@ type ( ImplNewFunc = func() any OptionsRaw = map[string]any - Middleware struct { - name string - construct ImplNewFunc - impl any + commonOptions = struct { // priority is only applied for ReverseProxy. // // Middleware compose follows the order of the slice // // Default is 10, 0 is the highest - priority int + Priority int `json:"priority"` + Bypass Bypass `json:"bypass"` + } + + Middleware struct { + name string + construct ImplNewFunc + impl any + + commonOptions } ByPriority []*Middleware @@ -55,7 +62,7 @@ const DefaultPriority = 10 func (m ByPriority) Len() int { return len(m) } func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] } -func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority } +func (m ByPriority) Less(i, j int) bool { return m[i].Priority < m[j].Priority } func NewMiddleware[ImplType any]() *Middleware { // type check @@ -107,21 +114,22 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error { if len(optsRaw) == 0 { return nil } - priority, ok := optsRaw["priority"].(int) - if ok { - m.priority = priority - // remove priority for deserialization, restore later - delete(optsRaw, "priority") - defer func() { - optsRaw["priority"] = priority - }() - } else { - m.priority = DefaultPriority + commonOpts := map[string]any{ + "priority": optsRaw["priority"], + "bypass": optsRaw["bypass"], + } + if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil { + return err + } + optsRaw = maps.Clone(optsRaw) + for k := range commonOpts { + delete(optsRaw, k) } return utils.MapUnmarshalValidate(optsRaw, m.impl) } func (m *Middleware) finalize() error { + m.impl = m.withCheckBypass() if finalizer, ok := m.impl.(MiddlewareFinalizer); ok { finalizer.finalize() return nil @@ -159,10 +167,16 @@ func (m *Middleware) String() string { } func (m *Middleware) MarshalJSON() ([]byte, error) { + type allOptions struct { + commonOptions + any + } return json.MarshalIndent(map[string]any{ - "name": m.name, - "options": m.impl, - "priority": m.priority, + "name": m.name, + "options": allOptions{ + commonOptions: m.commonOptions, + any: m.impl, + }, }, "", " ") } diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 7793e2e..2c427a7 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -16,6 +16,10 @@ type RuleOn struct { checker Checker } +func (on *RuleOn) Check(cached Cache, r *http.Request) bool { + return on.checker.Check(cached, r) +} + const ( OnHeader = "header" OnQuery = "query"