From 75ee0e63bd7d90549f04ca758271c8dd0ef420bf Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 8 May 2025 17:49:36 +0800 Subject: [PATCH] fix(middleware): fix route bypass matching - replace upstream headers approach with context value --- internal/entrypoint/entrypoint.go | 3 + internal/net/gphttp/httpheaders/utils.go | 5 -- internal/net/gphttp/middleware/bypass_test.go | 41 +++++++++++++ internal/net/gphttp/middleware/middleware.go | 1 - .../gphttp/middleware/set_upstream_headers.go | 37 ----------- internal/net/gphttp/middleware/vars.go | 33 +++------- internal/net/gphttp/reverseproxy/context.go | 61 +++++++++++++++++++ internal/route/rules/on.go | 4 +- internal/route/rules/on_test.go | 8 +-- 9 files changed, 116 insertions(+), 77 deletions(-) delete mode 100644 internal/net/gphttp/middleware/set_upstream_headers.go create mode 100644 internal/net/gphttp/reverseproxy/context.go diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 0f6d202..ec3842d 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -77,6 +77,9 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { return nil }) } + if rp, ok := mux.(routes.ReverseProxyRoute); ok { + r = rp.ReverseProxy().WithContextValue(r) + } if ep.middleware != nil { ep.middleware.ServeHTTP(mux.ServeHTTP, w, r) return diff --git a/internal/net/gphttp/httpheaders/utils.go b/internal/net/gphttp/httpheaders/utils.go index 2f348ca..4a032cc 100644 --- a/internal/net/gphttp/httpheaders/utils.go +++ b/internal/net/gphttp/httpheaders/utils.go @@ -20,11 +20,6 @@ const ( HeaderContentType = "Content-Type" HeaderContentLength = "Content-Length" - HeaderUpstreamName = "X-Godoxy-Upstream-Name" - HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme" - HeaderUpstreamHost = "X-Godoxy-Upstream-Host" - HeaderUpstreamPort = "X-Godoxy-Upstream-Port" - HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect" ) diff --git a/internal/net/gphttp/middleware/bypass_test.go b/internal/net/gphttp/middleware/bypass_test.go index 728b866..217d642 100644 --- a/internal/net/gphttp/middleware/bypass_test.go +++ b/internal/net/gphttp/middleware/bypass_test.go @@ -7,9 +7,13 @@ import ( "strings" "testing" + "github.com/yusing/go-proxy/internal/entrypoint" . "github.com/yusing/go-proxy/internal/net/gphttp/middleware" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/route" + routeTypes "github.com/yusing/go-proxy/internal/route/types" + "github.com/yusing/go-proxy/internal/task" expect "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -129,3 +133,40 @@ func TestReverseProxyBypass(t *testing.T) { }) } } + +func TestEntrypointBypassRoute(t *testing.T) { + go http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test")) + })) + entry := entrypoint.NewEntrypoint() + r := &route.Route{ + Alias: "test-route", + Port: routeTypes.Port{ + Proxy: 8080, + }, + } + err := entry.SetMiddlewares([]map[string]any{ + { + "use": "redirectHTTP", + "bypass": []string{"route test-route"}, + }, + { + "use": "response", + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + }, + }) + expect.NoError(t, err) + + err = r.Validate() + expect.NoError(t, err) + r.Start(task.RootTask("test", false)) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://test-route.example.com", nil) + entry.ServeHTTP(recorder, req) + expect.Equal(t, recorder.Code, http.StatusOK, "should bypass http redirect") + expect.Equal(t, recorder.Body.String(), "test") + expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value") +} diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index a39ffda..9f060ee 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -222,7 +222,6 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) ( func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { sort.Sort(ByPriority(middlewares)) - middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...) mid := NewMiddlewareChain(rp.TargetName, middlewares) diff --git a/internal/net/gphttp/middleware/set_upstream_headers.go b/internal/net/gphttp/middleware/set_upstream_headers.go deleted file mode 100644 index 434c4cd..0000000 --- a/internal/net/gphttp/middleware/set_upstream_headers.go +++ /dev/null @@ -1,37 +0,0 @@ -package middleware - -import ( - "net/http" - - "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" - "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" -) - -// internal use only. -type setUpstreamHeaders struct { - Name, Scheme, Host, Port string -} - -var suh = NewMiddleware[setUpstreamHeaders]() - -func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware { - m, err := suh.New(OptionsRaw{ - "name": rp.TargetName, - "scheme": rp.TargetURL.Scheme, - "host": rp.TargetURL.Hostname(), - "port": rp.TargetURL.Port(), - }) - if err != nil { - panic(err) - } - return m -} - -// before implements RequestModifier. -func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - r.Header.Set(httpheaders.HeaderUpstreamName, s.Name) - r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme) - r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host) - r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port) - return true -} diff --git a/internal/net/gphttp/middleware/vars.go b/internal/net/gphttp/middleware/vars.go index 472ca72..a612f2f 100644 --- a/internal/net/gphttp/middleware/vars.go +++ b/internal/net/gphttp/middleware/vars.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" ) type ( @@ -91,31 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ return "" }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, - VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) }, - VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) }, - VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) }, - VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) }, - VarUpstreamAddr: func(req *http.Request) string { - upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) - upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) - if upPort != "" { - return upHost + ":" + upPort - } - return upHost - }, - VarUpstreamURL: func(req *http.Request) string { - upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme) - if upScheme == "" { - return "" - } - upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) - upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) - upAddr := upHost - if upPort != "" { - upAddr += ":" + upPort - } - return upScheme + "://" + upAddr - }, + VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) }, + VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) }, + VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) }, + VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) }, + VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) }, + VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) }, } var staticRespVarSubsMap = map[string]respVarGetter{ diff --git a/internal/net/gphttp/reverseproxy/context.go b/internal/net/gphttp/reverseproxy/context.go new file mode 100644 index 0000000..c1a8d14 --- /dev/null +++ b/internal/net/gphttp/reverseproxy/context.go @@ -0,0 +1,61 @@ +package reverseproxy + +import ( + "context" + "net/http" +) + +var reverseProxyContextKey = struct{}{} + +func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request { + return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp)) +} + +func TryGetReverseProxy(r *http.Request) *ReverseProxy { + if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok { + return rp + } + return nil +} + +func TryGetUpstreamName(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetName + } + return "" +} + +func TryGetUpstreamScheme(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetURL.Scheme + } + return "" +} + +func TryGetUpstreamHost(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetURL.Hostname() + } + return "" +} + +func TryGetUpstreamPort(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetURL.Port() + } + return "" +} + +func TryGetUpstreamAddr(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetURL.Host + } + return "" +} + +func TryGetUpstreamURL(r *http.Request) string { + if rp := TryGetReverseProxy(r); rp != nil { + return rp.TargetURL.String() + } + return "" +} diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 099f08b..c5549f3 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -8,7 +8,7 @@ import ( "github.com/gobwas/glob" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -242,7 +242,7 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { route := args.(string) return func(_ Cache, r *http.Request) bool { - return r.Header.Get(httpheaders.HeaderUpstreamName) == route + return reverseproxy.TryGetUpstreamName(r) == route } }, }, diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index 1a1ed1c..88acd77 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" + "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" . "github.com/yusing/go-proxy/internal/utils/testing" "golang.org/x/crypto/bcrypt" ) @@ -305,11 +305,7 @@ func TestOnCorrectness(t *testing.T) { { name: "route_match", checker: "route example", - input: &http.Request{ - Header: http.Header{ - httpheaders.HeaderUpstreamName: {"example"}, - }, - }, + input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}), want: true, }, {