fix(middleware): fix route bypass matching

- replace upstream headers approach with context value
This commit is contained in:
yusing 2025-05-08 17:49:36 +08:00
parent 1ce607029a
commit 75ee0e63bd
9 changed files with 116 additions and 77 deletions

View file

@ -77,6 +77,9 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return nil
})
}
if rp, ok := mux.(routes.ReverseProxyRoute); ok {
r = rp.ReverseProxy().WithContextValue(r)
}
if ep.middleware != nil {
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
return

View file

@ -20,11 +20,6 @@ const (
HeaderContentType = "Content-Type"
HeaderContentLength = "Content-Length"
HeaderUpstreamName = "X-Godoxy-Upstream-Name"
HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-Godoxy-Upstream-Host"
HeaderUpstreamPort = "X-Godoxy-Upstream-Port"
HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect"
)

View file

@ -7,9 +7,13 @@ import (
"strings"
"testing"
"github.com/yusing/go-proxy/internal/entrypoint"
. "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route"
routeTypes "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/task"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -129,3 +133,40 @@ func TestReverseProxyBypass(t *testing.T) {
})
}
}
func TestEntrypointBypassRoute(t *testing.T) {
go http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
}))
entry := entrypoint.NewEntrypoint()
r := &route.Route{
Alias: "test-route",
Port: routeTypes.Port{
Proxy: 8080,
},
}
err := entry.SetMiddlewares([]map[string]any{
{
"use": "redirectHTTP",
"bypass": []string{"route test-route"},
},
{
"use": "response",
"set_headers": map[string]string{
"Test-Header": "test-value",
},
},
})
expect.NoError(t, err)
err = r.Validate()
expect.NoError(t, err)
r.Start(task.RootTask("test", false))
recorder := httptest.NewRecorder()
req := httptest.NewRequest("GET", "http://test-route.example.com", nil)
entry.ServeHTTP(recorder, req)
expect.Equal(t, recorder.Code, http.StatusOK, "should bypass http redirect")
expect.Equal(t, recorder.Body.String(), "test")
expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value")
}

View file

@ -222,7 +222,6 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
sort.Sort(ByPriority(middlewares))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares)

View file

@ -1,37 +0,0 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
)
// internal use only.
type setUpstreamHeaders struct {
Name, Scheme, Host, Port string
}
var suh = NewMiddleware[setUpstreamHeaders]()
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
m, err := suh.New(OptionsRaw{
"name": rp.TargetName,
"scheme": rp.TargetURL.Scheme,
"host": rp.TargetURL.Hostname(),
"port": rp.TargetURL.Port(),
})
if err != nil {
panic(err)
}
return m
}
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
)
type (
@ -91,31 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
}
return upHost
},
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" {
return ""
}
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort
}
return upScheme + "://" + upAddr
},
VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) },
VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) },
VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) },
VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) },
VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) },
VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) },
}
var staticRespVarSubsMap = map[string]respVarGetter{

View file

@ -0,0 +1,61 @@
package reverseproxy
import (
"context"
"net/http"
)
var reverseProxyContextKey = struct{}{}
func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request {
return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp))
}
func TryGetReverseProxy(r *http.Request) *ReverseProxy {
if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok {
return rp
}
return nil
}
func TryGetUpstreamName(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetName
}
return ""
}
func TryGetUpstreamScheme(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Scheme
}
return ""
}
func TryGetUpstreamHost(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Hostname()
}
return ""
}
func TryGetUpstreamPort(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Port()
}
return ""
}
func TryGetUpstreamAddr(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Host
}
return ""
}
func TryGetUpstreamURL(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.String()
}
return ""
}

View file

@ -8,7 +8,7 @@ import (
"github.com/gobwas/glob"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -242,7 +242,7 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
route := args.(string)
return func(_ Cache, r *http.Request) bool {
return r.Header.Get(httpheaders.HeaderUpstreamName) == route
return reverseproxy.TryGetUpstreamName(r) == route
}
},
},

View file

@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
. "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
@ -305,11 +305,7 @@ func TestOnCorrectness(t *testing.T) {
{
name: "route_match",
checker: "route example",
input: &http.Request{
Header: http.Header{
httpheaders.HeaderUpstreamName: {"example"},
},
},
input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}),
want: true,
},
{