mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix: middleware bypass
This commit is contained in:
parent
f1eefde964
commit
71ca8c738e
7 changed files with 294 additions and 273 deletions
|
@ -77,11 +77,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return nil
|
||||
})
|
||||
}
|
||||
if rp, ok := mux.(routes.ReverseProxyRoute); ok {
|
||||
r = rp.ReverseProxy().WithContextValue(r)
|
||||
}
|
||||
if ep.middleware != nil {
|
||||
ep.middleware.ServeHTTP(mux.ServeHTTP, w, r)
|
||||
ep.middleware.ServeHTTP(mux.ServeHTTP, w, routes.WithRouteContext(r, mux))
|
||||
return
|
||||
}
|
||||
mux.ServeHTTP(w, r)
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -91,12 +91,12 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: func(req *http.Request) string { return reverseproxy.TryGetUpstreamName(req) },
|
||||
VarUpstreamScheme: func(req *http.Request) string { return reverseproxy.TryGetUpstreamScheme(req) },
|
||||
VarUpstreamHost: func(req *http.Request) string { return reverseproxy.TryGetUpstreamHost(req) },
|
||||
VarUpstreamPort: func(req *http.Request) string { return reverseproxy.TryGetUpstreamPort(req) },
|
||||
VarUpstreamAddr: func(req *http.Request) string { return reverseproxy.TryGetUpstreamAddr(req) },
|
||||
VarUpstreamURL: func(req *http.Request) string { return reverseproxy.TryGetUpstreamURL(req) },
|
||||
VarUpstreamName: routes.TryGetUpstreamName,
|
||||
VarUpstreamScheme: routes.TryGetUpstreamScheme,
|
||||
VarUpstreamHost: routes.TryGetUpstreamHost,
|
||||
VarUpstreamPort: routes.TryGetUpstreamPort,
|
||||
VarUpstreamAddr: routes.TryGetUpstreamAddr,
|
||||
VarUpstreamURL: routes.TryGetUpstreamURL,
|
||||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var reverseProxyContextKey = struct{}{}
|
||||
|
||||
func (rp *ReverseProxy) WithContextValue(r *http.Request) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), reverseProxyContextKey, rp))
|
||||
}
|
||||
|
||||
func TryGetReverseProxy(r *http.Request) *ReverseProxy {
|
||||
if rp, ok := r.Context().Value(reverseProxyContextKey).(*ReverseProxy); ok {
|
||||
return rp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TryGetUpstreamName(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamScheme(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Scheme
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamHost(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Hostname()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamPort(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Port()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamAddr(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamURL(r *http.Request) string {
|
||||
if rp := TryGetReverseProxy(r); rp != nil {
|
||||
return rp.TargetURL.String()
|
||||
}
|
||||
return ""
|
||||
}
|
74
internal/route/routes/context.go
Normal file
74
internal/route/routes/context.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type RouteContext struct{}
|
||||
|
||||
var routeContextKey = RouteContext{}
|
||||
|
||||
func WithRouteContext(r *http.Request, route HTTPRoute) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), routeContextKey, route))
|
||||
}
|
||||
|
||||
func TryGetRoute(r *http.Request) HTTPRoute {
|
||||
if route, ok := r.Context().Value(routeContextKey).(HTTPRoute); ok {
|
||||
return route
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func tryGetURL(r *http.Request) *url.URL {
|
||||
if route := TryGetRoute(r); route != nil {
|
||||
u := route.TargetURL()
|
||||
if u != nil {
|
||||
return &u.URL
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TryGetUpstreamName(r *http.Request) string {
|
||||
if route := TryGetRoute(r); route != nil {
|
||||
return route.Name()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamScheme(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Scheme
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamHost(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Hostname()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamPort(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Port()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamAddr(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamURL(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.String()
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -8,8 +8,8 @@ import (
|
|||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -242,7 +242,7 @@ var checkers = map[string]struct {
|
|||
builder: func(args any) CheckFunc {
|
||||
route := args.(string)
|
||||
return func(_ Cache, r *http.Request) bool {
|
||||
return reverseproxy.TryGetUpstreamName(r) == route
|
||||
return routes.TryGetUpstreamName(r) == route
|
||||
}
|
||||
},
|
||||
},
|
||||
|
|
195
internal/route/rules/on_internal_test.go
Normal file
195
internal/route/rules/on_internal_test.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package rules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSplitAnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
input: "",
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
input: "rule",
|
||||
want: []string{"rule"},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
input: "rule1 & rule2",
|
||||
want: []string{"rule1", "rule2"},
|
||||
},
|
||||
{
|
||||
name: "multiple_newline",
|
||||
input: "rule1\n\nrule2",
|
||||
want: []string{"rule1", "rule2"},
|
||||
},
|
||||
{
|
||||
name: "multiple_newline_and",
|
||||
input: "rule1\nrule2 & rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "empty segment",
|
||||
input: "rule1\n& &rule2& rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "double_and",
|
||||
input: "rule1\nrule2 && rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "spaces_around",
|
||||
input: " rule1\nrule2 & rule3 ",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitAnd(tt.input)
|
||||
expect.Equal(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
// header
|
||||
{
|
||||
name: "header_valid_kv",
|
||||
input: "header Connection Upgrade",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header_valid_k",
|
||||
input: "header Connection",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header_missing_arg",
|
||||
input: "header",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
// query
|
||||
{
|
||||
name: "query_valid_kv",
|
||||
input: "query key value",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "query_valid_k",
|
||||
input: "query key",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "query_missing_arg",
|
||||
input: "query",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
{
|
||||
name: "cookie_valid_kv",
|
||||
input: "cookie key value",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie_valid_k",
|
||||
input: "cookie key",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie_missing_arg",
|
||||
input: "cookie",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
// method
|
||||
{
|
||||
name: "method_valid",
|
||||
input: "method GET",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "method_invalid",
|
||||
input: "method invalid",
|
||||
wantErr: ErrInvalidArguments,
|
||||
},
|
||||
{
|
||||
name: "method_missing_arg",
|
||||
input: "method",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
// path
|
||||
{
|
||||
name: "path_valid",
|
||||
input: "path /home",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "path_missing_arg",
|
||||
input: "path",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
// remote
|
||||
{
|
||||
name: "remote_valid",
|
||||
input: "remote 127.0.0.1",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "remote_invalid",
|
||||
input: "remote abcd",
|
||||
wantErr: ErrInvalidArguments,
|
||||
},
|
||||
{
|
||||
name: "remote_missing_arg",
|
||||
input: "remote",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
{
|
||||
name: "unknown_target",
|
||||
input: "unknown",
|
||||
wantErr: ErrInvalidOnTarget,
|
||||
},
|
||||
// route
|
||||
{
|
||||
name: "route_valid",
|
||||
input: "route example",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "route_missing_arg",
|
||||
input: "route",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
{
|
||||
name: "route_extra_arg",
|
||||
input: "route example1 example2",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
on := &RuleOn{}
|
||||
err := on.Parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
expect.HasError(t, tt.wantErr, err)
|
||||
} else {
|
||||
expect.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package rules
|
||||
package rules_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
@ -7,199 +7,13 @@ import (
|
|||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
. "github.com/yusing/go-proxy/internal/route/rules"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestSplitAnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
input: "",
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
input: "rule",
|
||||
want: []string{"rule"},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
input: "rule1 & rule2",
|
||||
want: []string{"rule1", "rule2"},
|
||||
},
|
||||
{
|
||||
name: "multiple_newline",
|
||||
input: "rule1\n\nrule2",
|
||||
want: []string{"rule1", "rule2"},
|
||||
},
|
||||
{
|
||||
name: "multiple_newline_and",
|
||||
input: "rule1\nrule2 & rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "empty segment",
|
||||
input: "rule1\n& &rule2& rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "double_and",
|
||||
input: "rule1\nrule2 && rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "spaces_around",
|
||||
input: " rule1\nrule2 & rule3 ",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitAnd(tt.input)
|
||||
ExpectEqual(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
// header
|
||||
{
|
||||
name: "header_valid_kv",
|
||||
input: "header Connection Upgrade",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header_valid_k",
|
||||
input: "header Connection",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header_missing_arg",
|
||||
input: "header",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
// query
|
||||
{
|
||||
name: "query_valid_kv",
|
||||
input: "query key value",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "query_valid_k",
|
||||
input: "query key",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "query_missing_arg",
|
||||
input: "query",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
{
|
||||
name: "cookie_valid_kv",
|
||||
input: "cookie key value",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie_valid_k",
|
||||
input: "cookie key",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie_missing_arg",
|
||||
input: "cookie",
|
||||
wantErr: ErrExpectKVOptionalV,
|
||||
},
|
||||
// method
|
||||
{
|
||||
name: "method_valid",
|
||||
input: "method GET",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "method_invalid",
|
||||
input: "method invalid",
|
||||
wantErr: ErrInvalidArguments,
|
||||
},
|
||||
{
|
||||
name: "method_missing_arg",
|
||||
input: "method",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
// path
|
||||
{
|
||||
name: "path_valid",
|
||||
input: "path /home",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "path_missing_arg",
|
||||
input: "path",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
// remote
|
||||
{
|
||||
name: "remote_valid",
|
||||
input: "remote 127.0.0.1",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "remote_invalid",
|
||||
input: "remote abcd",
|
||||
wantErr: ErrInvalidArguments,
|
||||
},
|
||||
{
|
||||
name: "remote_missing_arg",
|
||||
input: "remote",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
{
|
||||
name: "unknown_target",
|
||||
input: "unknown",
|
||||
wantErr: ErrInvalidOnTarget,
|
||||
},
|
||||
// route
|
||||
{
|
||||
name: "route_valid",
|
||||
input: "route example",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "route_missing_arg",
|
||||
input: "route",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
{
|
||||
name: "route_extra_arg",
|
||||
input: "route example1 example2",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
on := &RuleOn{}
|
||||
err := on.Parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
ExpectError(t, tt.wantErr, err)
|
||||
} else {
|
||||
ExpectNoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testCorrectness struct {
|
||||
name string
|
||||
checker string
|
||||
|
@ -284,7 +98,7 @@ func TestOnCorrectness(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "basic_auth_correct",
|
||||
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
input: &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
|
||||
|
@ -294,7 +108,7 @@ func TestOnCorrectness(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "basic_auth_incorrect",
|
||||
checker: "basic_auth user " + string(Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
checker: "basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
input: &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
|
||||
|
@ -305,7 +119,10 @@ func TestOnCorrectness(t *testing.T) {
|
|||
{
|
||||
name: "route_match",
|
||||
checker: "route example",
|
||||
input: reverseproxy.NewReverseProxy("example", nil, http.DefaultTransport).WithContextValue(&http.Request{}),
|
||||
input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
|
||||
Alias: "example",
|
||||
Root: "/",
|
||||
}))),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
|
@ -354,12 +171,11 @@ func TestOnCorrectness(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
on, err := parseOn(tt.checker)
|
||||
ExpectNoError(t, err)
|
||||
var on RuleOn
|
||||
err := on.Parse(tt.checker)
|
||||
expect.NoError(t, err)
|
||||
got := on.Check(Cache{}, tt.input)
|
||||
if tt.want != got {
|
||||
t.Errorf("want %v, got %v", tt.want, got)
|
||||
}
|
||||
expect.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue