fix: middleware bypass

This commit is contained in:
yusing 2025-05-11 06:33:22 +08:00
parent f1eefde964
commit 71ca8c738e
7 changed files with 294 additions and 273 deletions

View file

@ -77,11 +77,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return nil
})
}
if rp, ok := mux.(routes.ReverseProxyRoute); ok {
r = rp.ReverseProxy().WithContextValue(r)
}
if ep.middleware != nil {
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
ep.middleware.ServeHTTP(mux.ServeHTTP, w, routes.WithRouteContext(r, mux))
return
}
mux.ServeHTTP(w, r)

View file

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/route/routes"
)
type (
@ -91,12 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) },
VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) },
VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) },
VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) },
VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) },
VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) },
VarUpstreamName: routes.TryGetUpstreamName,
VarUpstreamScheme: routes.TryGetUpstreamScheme,
VarUpstreamHost: routes.TryGetUpstreamHost,
VarUpstreamPort: routes.TryGetUpstreamPort,
VarUpstreamAddr: routes.TryGetUpstreamAddr,
VarUpstreamURL: routes.TryGetUpstreamURL,
}
var staticRespVarSubsMap = map[string]respVarGetter{

View file

@ -1,61 +0,0 @@
package reverseproxy
import (
"context"
"net/http"
)
var reverseProxyContextKey = struct{}{}
func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request {
return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp))
}
func TryGetReverseProxy(r *http.Request) *ReverseProxy {
if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok {
return rp
}
return nil
}
func TryGetUpstreamName(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetName
}
return ""
}
func TryGetUpstreamScheme(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Scheme
}
return ""
}
func TryGetUpstreamHost(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Hostname()
}
return ""
}
func TryGetUpstreamPort(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Port()
}
return ""
}
func TryGetUpstreamAddr(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.Host
}
return ""
}
func TryGetUpstreamURL(r *http.Request) string {
if rp := TryGetReverseProxy(r); rp != nil {
return rp.TargetURL.String()
}
return ""
}

View file

@ -0,0 +1,74 @@
package routes
import (
"context"
"net/http"
"net/url"
)
type RouteContext struct{}
var routeContextKey = RouteContext{}
func WithRouteContext(r *http.Request, route HTTPRoute) *http.Request {
return r.WithContext(context.WithValue(r.Context(), routeContextKey, route))
}
func TryGetRoute(r *http.Request) HTTPRoute {
if route, ok := r.Context().Value(routeContextKey).(HTTPRoute); ok {
return route
}
return nil
}
func tryGetURL(r *http.Request) *url.URL {
if route := TryGetRoute(r); route != nil {
u := route.TargetURL()
if u != nil {
return &u.URL
}
}
return nil
}
func TryGetUpstreamName(r *http.Request) string {
if route := TryGetRoute(r); route != nil {
return route.Name()
}
return ""
}
func TryGetUpstreamScheme(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Scheme
}
return ""
}
func TryGetUpstreamHost(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Hostname()
}
return ""
}
func TryGetUpstreamPort(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Port()
}
return ""
}
func TryGetUpstreamAddr(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Host
}
return ""
}
func TryGetUpstreamURL(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.String()
}
return ""
}

View file

@ -8,8 +8,8 @@ import (
"github.com/gobwas/glob"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -242,7 +242,7 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
route := args.(string)
return func(_ Cache, r *http.Request) bool {
return reverseproxy.TryGetUpstreamName(r) == route
return routes.TryGetUpstreamName(r) == route
}
},
},

View file

@ -0,0 +1,195 @@
package rules
import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSplitAnd(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "empty",
input: "",
want: []string{},
},
{
name: "single",
input: "rule",
want: []string{"rule"},
},
{
name: "multiple",
input: "rule1 & rule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline",
input: "rule1\n\nrule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline_and",
input: "rule1\nrule2 & rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "empty segment",
input: "rule1\n& &rule2& rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "double_and",
input: "rule1\nrule2 && rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "spaces_around",
input: " rule1\nrule2 & rule3 ",
want: []string{"rule1", "rule2", "rule3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitAnd(tt.input)
expect.Equal(t, got, tt.want)
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr gperr.Error
}{
// header
{
name: "header_valid_kv",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_valid_k",
input: "header Connection",
wantErr: nil,
},
{
name: "header_missing_arg",
input: "header",
wantErr: ErrExpectKVOptionalV,
},
// query
{
name: "query_valid_kv",
input: "query key value",
wantErr: nil,
},
{
name: "query_valid_k",
input: "query key",
wantErr: nil,
},
{
name: "query_missing_arg",
input: "query",
wantErr: ErrExpectKVOptionalV,
},
{
name: "cookie_valid_kv",
input: "cookie key value",
wantErr: nil,
},
{
name: "cookie_valid_k",
input: "cookie key",
wantErr: nil,
},
{
name: "cookie_missing_arg",
input: "cookie",
wantErr: ErrExpectKVOptionalV,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method invalid",
wantErr: ErrInvalidArguments,
},
{
name: "method_missing_arg",
input: "method",
wantErr: ErrExpectOneArg,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_missing_arg",
input: "path",
wantErr: ErrExpectOneArg,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote abcd",
wantErr: ErrInvalidArguments,
},
{
name: "remote_missing_arg",
input: "remote",
wantErr: ErrExpectOneArg,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
// route
{
name: "route_valid",
input: "route example",
wantErr: nil,
},
{
name: "route_missing_arg",
input: "route",
wantErr: ErrExpectOneArg,
},
{
name: "route_extra_arg",
input: "route example1 example2",
wantErr: ErrExpectOneArg,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
expect.HasError(t, tt.wantErr, err)
} else {
expect.NoError(t, err)
}
})
}
}

View file

@ -1,4 +1,4 @@
package rules
package rules_test
import (
"encoding/base64"
@ -7,199 +7,13 @@ import (
"net/url"
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/routes"
. "github.com/yusing/go-proxy/internal/route/rules"
expect "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
func TestSplitAnd(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "empty",
input: "",
want: []string{},
},
{
name: "single",
input: "rule",
want: []string{"rule"},
},
{
name: "multiple",
input: "rule1 & rule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline",
input: "rule1\n\nrule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_newline_and",
input: "rule1\nrule2 & rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "empty segment",
input: "rule1\n& &rule2& rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "double_and",
input: "rule1\nrule2 && rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "spaces_around",
input: " rule1\nrule2 & rule3 ",
want: []string{"rule1", "rule2", "rule3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitAnd(tt.input)
ExpectEqual(t, got, tt.want)
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr gperr.Error
}{
// header
{
name: "header_valid_kv",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_valid_k",
input: "header Connection",
wantErr: nil,
},
{
name: "header_missing_arg",
input: "header",
wantErr: ErrExpectKVOptionalV,
},
// query
{
name: "query_valid_kv",
input: "query key value",
wantErr: nil,
},
{
name: "query_valid_k",
input: "query key",
wantErr: nil,
},
{
name: "query_missing_arg",
input: "query",
wantErr: ErrExpectKVOptionalV,
},
{
name: "cookie_valid_kv",
input: "cookie key value",
wantErr: nil,
},
{
name: "cookie_valid_k",
input: "cookie key",
wantErr: nil,
},
{
name: "cookie_missing_arg",
input: "cookie",
wantErr: ErrExpectKVOptionalV,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method invalid",
wantErr: ErrInvalidArguments,
},
{
name: "method_missing_arg",
input: "method",
wantErr: ErrExpectOneArg,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_missing_arg",
input: "path",
wantErr: ErrExpectOneArg,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote abcd",
wantErr: ErrInvalidArguments,
},
{
name: "remote_missing_arg",
input: "remote",
wantErr: ErrExpectOneArg,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
// route
{
name: "route_valid",
input: "route example",
wantErr: nil,
},
{
name: "route_missing_arg",
input: "route",
wantErr: ErrExpectOneArg,
},
{
name: "route_extra_arg",
input: "route example1 example2",
wantErr: ErrExpectOneArg,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
type testCorrectness struct {
name string
checker string
@ -284,7 +98,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_correct",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
@ -294,7 +108,7 @@ func TestOnCorrectness(t *testing.T) {
},
{
name: "basic_auth_incorrect",
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
@ -305,7 +119,10 @@ func TestOnCorrectness(t *testing.T) {
{
name: "route_match",
checker: "route example",
input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}),
input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
Alias: "example",
Root: "/",
}))),
want: true,
},
{
@ -354,12 +171,11 @@ func TestOnCorrectness(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on, err := parseOn(tt.checker)
ExpectNoError(t, err)
var on RuleOn
err := on.Parse(tt.checker)
expect.NoError(t, err)
got := on.Check(Cache{}, tt.input)
if tt.want != got {
t.Errorf("want %v, got %v", tt.want, got)
}
expect.Equal(t, tt.want, got)
})
}
}