mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix(middleware): fix route bypass matching
- replace upstream headers approach with context value
This commit is contained in:
parent
1ce607029a
commit
75ee0e63bd
9 changed files with 116 additions and 77 deletions
|
@ -77,6 +77,9 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return nil
|
||||
})
|
||||
}
|
||||
if rp, ok := mux.(routes.ReverseProxyRoute); ok {
|
||||
r = rp.ReverseProxy().WithContextValue(r)
|
||||
}
|
||||
if ep.middleware != nil {
|
||||
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
|
||||
return
|
||||
|
|
|
@ -20,11 +20,6 @@ const (
|
|||
HeaderContentType = "Content-Type"
|
||||
HeaderContentLength = "Content-Length"
|
||||
|
||||
HeaderUpstreamName = "X-Godoxy-Upstream-Name"
|
||||
HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme"
|
||||
HeaderUpstreamHost = "X-Godoxy-Upstream-Host"
|
||||
HeaderUpstreamPort = "X-Godoxy-Upstream-Port"
|
||||
|
||||
HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect"
|
||||
)
|
||||
|
||||
|
|
|
@ -7,9 +7,13 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/entrypoint"
|
||||
. "github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
routeTypes "github.com/yusing/go-proxy/internal/route/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -129,3 +133,40 @@ func TestReverseProxyBypass(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntrypointBypassRoute(t *testing.T) {
|
||||
go http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test"))
|
||||
}))
|
||||
entry := entrypoint.NewEntrypoint()
|
||||
r := &route.Route{
|
||||
Alias: "test-route",
|
||||
Port: routeTypes.Port{
|
||||
Proxy: 8080,
|
||||
},
|
||||
}
|
||||
err := entry.SetMiddlewares([]map[string]any{
|
||||
{
|
||||
"use": "redirectHTTP",
|
||||
"bypass": []string{"route test-route"},
|
||||
},
|
||||
{
|
||||
"use": "response",
|
||||
"set_headers": map[string]string{
|
||||
"Test-Header": "test-value",
|
||||
},
|
||||
},
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
|
||||
err = r.Validate()
|
||||
expect.NoError(t, err)
|
||||
r.Start(task.RootTask("test", false))
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "http://test-route.example.com", nil)
|
||||
entry.ServeHTTP(recorder, req)
|
||||
expect.Equal(t, recorder.Code, http.StatusOK, "should bypass http redirect")
|
||||
expect.Equal(t, recorder.Body.String(), "test")
|
||||
expect.Equal(t, recorder.Header().Get("Test-Header"), "test-value")
|
||||
}
|
||||
|
|
|
@ -222,7 +222,6 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
|||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
sort.Sort(ByPriority(middlewares))
|
||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||
|
||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
)
|
||||
|
||||
// internal use only.
|
||||
type setUpstreamHeaders struct {
|
||||
Name, Scheme, Host, Port string
|
||||
}
|
||||
|
||||
var suh = NewMiddleware[setUpstreamHeaders]()
|
||||
|
||||
func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
|
||||
m, err := suh.New(OptionsRaw{
|
||||
"name": rp.TargetName,
|
||||
"scheme": rp.TargetURL.Scheme,
|
||||
"host": rp.TargetURL.Hostname(),
|
||||
"port": rp.TargetURL.Port(),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// before implements RequestModifier.
|
||||
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
|
||||
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
|
||||
return true
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -91,31 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *http.Request) string {
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
return upHost + ":" + upPort
|
||||
}
|
||||
return upHost
|
||||
},
|
||||
VarUpstreamURL: func(req *http.Request) string {
|
||||
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
}
|
||||
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
|
||||
upAddr := upHost
|
||||
if upPort != "" {
|
||||
upAddr += ":" + upPort
|
||||
}
|
||||
return upScheme + "://" + upAddr
|
||||
},
|
||||
VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) },
|
||||
VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) },
|
||||
VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) },
|
||||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
|
|
61
internal/net/gphttp/reverseproxy/context.go
Normal file
61
internal/net/gphttp/reverseproxy/context.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var reverseProxyContextKey = struct{}{}
|
||||
|
||||
func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp))
|
||||
}
|
||||
|
||||
func TryGetReverseProxy(r *http.Request) *ReverseProxy {
|
||||
if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok {
|
||||
return rp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TryGetUpstreamName(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamScheme(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Scheme
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamHost(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Hostname()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamPort(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Port()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamAddr(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamURL(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.String()
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
@ -242,7 +242,7 @@ var checkers = map[string]struct {
|
|||
builder: func(args any) CheckFunc {
|
||||
route := args.(string)
|
||||
return func(_ Cache, r *http.Request) bool {
|
||||
return r.Header.Get(httpheaders.HeaderUpstreamName) == route
|
||||
return reverseproxy.TryGetUpstreamName(r) == route
|
||||
}
|
||||
},
|
||||
},
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -305,11 +305,7 @@ func TestOnCorrectness(t *testing.T) {
|
|||
{
|
||||
name: "route_match",
|
||||
checker: "route example",
|
||||
input: &http.Request{
|
||||
Header: http.Header{
|
||||
httpheaders.HeaderUpstreamName: {"example"},
|
||||
},
|
||||
},
|
||||
input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
|
|
Loading…
Add table
Reference in a new issue