From 71ca8c738ed32e71a1d22a6bb9b294ae5a203899 Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 11 May 2025 06:33:22 +0800 Subject: [PATCH] fix: middleware bypass --- internal/entrypoint/entrypoint.go | 5 +- internal/net/gphttp/middleware/vars.go | 14 +- internal/net/gphttp/reverseproxy/context.go | 61 ------ internal/route/routes/context.go | 74 +++++++ internal/route/rules/on.go | 4 +- internal/route/rules/on_internal_test.go | 195 ++++++++++++++++++ internal/route/rules/on_test.go | 214 ++------------------ 7 files changed, 294 insertions(+), 273 deletions(-) delete mode 100644 internal/net/gphttp/reverseproxy/context.go create mode 100644 internal/route/routes/context.go create mode 100644 internal/route/rules/on_internal_test.go diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index ec3842d..e3683db 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -77,11 +77,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { return nil }) } - if rp, ok := mux.(routes.ReverseProxyRoute); ok { - r = rp.ReverseProxy().WithContextValue(r) - } if ep.middleware != nil { - ep.middleware.ServeHTTP(mux.ServeHTTP, w, r) + ep.middleware.ServeHTTP(mux.ServeHTTP, w, routes.WithRouteContext(r, mux)) return } mux.ServeHTTP(w, r) diff --git a/internal/net/gphttp/middleware/vars.go b/internal/net/gphttp/middleware/vars.go index a612f2f..6c5f7fc 100644 --- a/internal/net/gphttp/middleware/vars.go +++ b/internal/net/gphttp/middleware/vars.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" + "github.com/yusing/go-proxy/internal/route/routes" ) type ( @@ -91,12 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ return "" }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, - VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) }, - VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) }, - VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) }, - VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) }, - VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) }, - VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) }, + VarUpstreamName: routes.TryGetUpstreamName, + VarUpstreamScheme: routes.TryGetUpstreamScheme, + VarUpstreamHost: routes.TryGetUpstreamHost, + VarUpstreamPort: routes.TryGetUpstreamPort, + VarUpstreamAddr: routes.TryGetUpstreamAddr, + VarUpstreamURL: routes.TryGetUpstreamURL, } var staticRespVarSubsMap = map[string]respVarGetter{ diff --git a/internal/net/gphttp/reverseproxy/context.go b/internal/net/gphttp/reverseproxy/context.go deleted file mode 100644 index c1a8d14..0000000 --- a/internal/net/gphttp/reverseproxy/context.go +++ /dev/null @@ -1,61 +0,0 @@ -package reverseproxy - -import ( - "context" - "net/http" -) - -var reverseProxyContextKey = struct{}{} - -func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request { - return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp)) -} - -func TryGetReverseProxy(r *http.Request) *ReverseProxy { - if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok { - return rp - } - return nil -} - -func TryGetUpstreamName(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetName - } - return "" -} - -func TryGetUpstreamScheme(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetURL.Scheme - } - return "" -} - -func TryGetUpstreamHost(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetURL.Hostname() - } - return "" -} - -func TryGetUpstreamPort(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetURL.Port() - } - return "" -} - -func TryGetUpstreamAddr(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetURL.Host - } - return "" -} - -func TryGetUpstreamURL(r *http.Request) string { - if rp := TryGetReverseProxy(r); rp != nil { - return rp.TargetURL.String() - } - return "" -} diff --git a/internal/route/routes/context.go b/internal/route/routes/context.go new file mode 100644 index 0000000..7426ccf --- /dev/null +++ b/internal/route/routes/context.go @@ -0,0 +1,74 @@ +package routes + +import ( + "context" + "net/http" + "net/url" +) + +type RouteContext struct{} + +var routeContextKey = RouteContext{} + +func WithRouteContext(r *http.Request, route HTTPRoute) *http.Request { + return r.WithContext(context.WithValue(r.Context(), routeContextKey, route)) +} + +func TryGetRoute(r *http.Request) HTTPRoute { + if route, ok := r.Context().Value(routeContextKey).(HTTPRoute); ok { + return route + } + return nil +} + +func tryGetURL(r *http.Request) *url.URL { + if route := TryGetRoute(r); route != nil { + u := route.TargetURL() + if u != nil { + return &u.URL + } + } + return nil +} + +func TryGetUpstreamName(r *http.Request) string { + if route := TryGetRoute(r); route != nil { + return route.Name() + } + return "" +} + +func TryGetUpstreamScheme(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.Scheme + } + return "" +} + +func TryGetUpstreamHost(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.Hostname() + } + return "" +} + +func TryGetUpstreamPort(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.Port() + } + return "" +} + +func TryGetUpstreamAddr(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.Host + } + return "" +} + +func TryGetUpstreamURL(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.String() + } + return "" +} diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index c5549f3..2eb776c 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -8,8 +8,8 @@ import ( "github.com/gobwas/glob" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -242,7 +242,7 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { route := args.(string) return func(_ Cache, r *http.Request) bool { - return reverseproxy.TryGetUpstreamName(r) == route + return routes.TryGetUpstreamName(r) == route } }, }, diff --git a/internal/route/rules/on_internal_test.go b/internal/route/rules/on_internal_test.go new file mode 100644 index 0000000..08ad0dc --- /dev/null +++ b/internal/route/rules/on_internal_test.go @@ -0,0 +1,195 @@ +package rules + +import ( + "testing" + + "github.com/yusing/go-proxy/internal/gperr" + expect "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestSplitAnd(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "empty", + input: "", + want: []string{}, + }, + { + name: "single", + input: "rule", + want: []string{"rule"}, + }, + { + name: "multiple", + input: "rule1 & rule2", + want: []string{"rule1", "rule2"}, + }, + { + name: "multiple_newline", + input: "rule1\n\nrule2", + want: []string{"rule1", "rule2"}, + }, + { + name: "multiple_newline_and", + input: "rule1\nrule2 & rule3", + want: []string{"rule1", "rule2", "rule3"}, + }, + { + name: "empty segment", + input: "rule1\n& &rule2& rule3", + want: []string{"rule1", "rule2", "rule3"}, + }, + { + name: "double_and", + input: "rule1\nrule2 && rule3", + want: []string{"rule1", "rule2", "rule3"}, + }, + { + name: "spaces_around", + input: " rule1\nrule2 & rule3 ", + want: []string{"rule1", "rule2", "rule3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitAnd(tt.input) + expect.Equal(t, got, tt.want) + }) + } +} + +func TestParseOn(t *testing.T) { + tests := []struct { + name string + input string + wantErr gperr.Error + }{ + // header + { + name: "header_valid_kv", + input: "header Connection Upgrade", + wantErr: nil, + }, + { + name: "header_valid_k", + input: "header Connection", + wantErr: nil, + }, + { + name: "header_missing_arg", + input: "header", + wantErr: ErrExpectKVOptionalV, + }, + // query + { + name: "query_valid_kv", + input: "query key value", + wantErr: nil, + }, + { + name: "query_valid_k", + input: "query key", + wantErr: nil, + }, + { + name: "query_missing_arg", + input: "query", + wantErr: ErrExpectKVOptionalV, + }, + { + name: "cookie_valid_kv", + input: "cookie key value", + wantErr: nil, + }, + { + name: "cookie_valid_k", + input: "cookie key", + wantErr: nil, + }, + { + name: "cookie_missing_arg", + input: "cookie", + wantErr: ErrExpectKVOptionalV, + }, + // method + { + name: "method_valid", + input: "method GET", + wantErr: nil, + }, + { + name: "method_invalid", + input: "method invalid", + wantErr: ErrInvalidArguments, + }, + { + name: "method_missing_arg", + input: "method", + wantErr: ErrExpectOneArg, + }, + // path + { + name: "path_valid", + input: "path /home", + wantErr: nil, + }, + { + name: "path_missing_arg", + input: "path", + wantErr: ErrExpectOneArg, + }, + // remote + { + name: "remote_valid", + input: "remote 127.0.0.1", + wantErr: nil, + }, + { + name: "remote_invalid", + input: "remote abcd", + wantErr: ErrInvalidArguments, + }, + { + name: "remote_missing_arg", + input: "remote", + wantErr: ErrExpectOneArg, + }, + { + name: "unknown_target", + input: "unknown", + wantErr: ErrInvalidOnTarget, + }, + // route + { + name: "route_valid", + input: "route example", + wantErr: nil, + }, + { + name: "route_missing_arg", + input: "route", + wantErr: ErrExpectOneArg, + }, + { + name: "route_extra_arg", + input: "route example1 example2", + wantErr: ErrExpectOneArg, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + on := &RuleOn{} + err := on.Parse(tt.input) + if tt.wantErr != nil { + expect.HasError(t, tt.wantErr, err) + } else { + expect.NoError(t, err) + } + }) + } +} diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index 88acd77..fdb8180 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -1,4 +1,4 @@ -package rules +package rules_test import ( "encoding/base64" @@ -7,199 +7,13 @@ import ( "net/url" "testing" - "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" - . "github.com/yusing/go-proxy/internal/utils/testing" + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/routes" + . "github.com/yusing/go-proxy/internal/route/rules" + expect "github.com/yusing/go-proxy/internal/utils/testing" "golang.org/x/crypto/bcrypt" ) -func TestSplitAnd(t *testing.T) { - tests := []struct { - name string - input string - want []string - }{ - { - name: "empty", - input: "", - want: []string{}, - }, - { - name: "single", - input: "rule", - want: []string{"rule"}, - }, - { - name: "multiple", - input: "rule1 & rule2", - want: []string{"rule1", "rule2"}, - }, - { - name: "multiple_newline", - input: "rule1\n\nrule2", - want: []string{"rule1", "rule2"}, - }, - { - name: "multiple_newline_and", - input: "rule1\nrule2 & rule3", - want: []string{"rule1", "rule2", "rule3"}, - }, - { - name: "empty segment", - input: "rule1\n& &rule2& rule3", - want: []string{"rule1", "rule2", "rule3"}, - }, - { - name: "double_and", - input: "rule1\nrule2 && rule3", - want: []string{"rule1", "rule2", "rule3"}, - }, - { - name: "spaces_around", - input: " rule1\nrule2 & rule3 ", - want: []string{"rule1", "rule2", "rule3"}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := splitAnd(tt.input) - ExpectEqual(t, got, tt.want) - }) - } -} - -func TestParseOn(t *testing.T) { - tests := []struct { - name string - input string - wantErr gperr.Error - }{ - // header - { - name: "header_valid_kv", - input: "header Connection Upgrade", - wantErr: nil, - }, - { - name: "header_valid_k", - input: "header Connection", - wantErr: nil, - }, - { - name: "header_missing_arg", - input: "header", - wantErr: ErrExpectKVOptionalV, - }, - // query - { - name: "query_valid_kv", - input: "query key value", - wantErr: nil, - }, - { - name: "query_valid_k", - input: "query key", - wantErr: nil, - }, - { - name: "query_missing_arg", - input: "query", - wantErr: ErrExpectKVOptionalV, - }, - { - name: "cookie_valid_kv", - input: "cookie key value", - wantErr: nil, - }, - { - name: "cookie_valid_k", - input: "cookie key", - wantErr: nil, - }, - { - name: "cookie_missing_arg", - input: "cookie", - wantErr: ErrExpectKVOptionalV, - }, - // method - { - name: "method_valid", - input: "method GET", - wantErr: nil, - }, - { - name: "method_invalid", - input: "method invalid", - wantErr: ErrInvalidArguments, - }, - { - name: "method_missing_arg", - input: "method", - wantErr: ErrExpectOneArg, - }, - // path - { - name: "path_valid", - input: "path /home", - wantErr: nil, - }, - { - name: "path_missing_arg", - input: "path", - wantErr: ErrExpectOneArg, - }, - // remote - { - name: "remote_valid", - input: "remote 127.0.0.1", - wantErr: nil, - }, - { - name: "remote_invalid", - input: "remote abcd", - wantErr: ErrInvalidArguments, - }, - { - name: "remote_missing_arg", - input: "remote", - wantErr: ErrExpectOneArg, - }, - { - name: "unknown_target", - input: "unknown", - wantErr: ErrInvalidOnTarget, - }, - // route - { - name: "route_valid", - input: "route example", - wantErr: nil, - }, - { - name: "route_missing_arg", - input: "route", - wantErr: ErrExpectOneArg, - }, - { - name: "route_extra_arg", - input: "route example1 example2", - wantErr: ErrExpectOneArg, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - on := &RuleOn{} - err := on.Parse(tt.input) - if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) - } else { - ExpectNoError(t, err) - } - }) - } -} - type testCorrectness struct { name string checker string @@ -284,7 +98,7 @@ func TestOnCorrectness(t *testing.T) { }, { name: "basic_auth_correct", - checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), input: &http.Request{ Header: http.Header{ "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password" @@ -294,7 +108,7 @@ func TestOnCorrectness(t *testing.T) { }, { name: "basic_auth_incorrect", - checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), input: &http.Request{ Header: http.Header{ "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong" @@ -305,7 +119,10 @@ func TestOnCorrectness(t *testing.T) { { name: "route_match", checker: "route example", - input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}), + input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{ + Alias: "example", + Root: "/", + }))), want: true, }, { @@ -354,12 +171,11 @@ func TestOnCorrectness(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - on, err := parseOn(tt.checker) - ExpectNoError(t, err) + var on RuleOn + err := on.Parse(tt.checker) + expect.NoError(t, err) got := on.Check(Cache{}, tt.input) - if tt.want != got { - t.Errorf("want %v, got %v", tt.want, got) - } + expect.Equal(t, tt.want, got) }) } }