From 62d3d200e6ee8eb955107439e65c9d3b82704b6f Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 5 May 2025 19:34:24 +0800 Subject: [PATCH] feat(rules.on): support route directive --- internal/route/rules/on.go | 17 ++++++++++++++++ internal/route/rules/on_test.go | 35 ++++++++++++++++++++++++++++++++ internal/route/rules/validate.go | 7 +++++++ 3 files changed, 59 insertions(+) diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 61cc752..099f08b 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -8,6 +8,7 @@ import ( "github.com/gobwas/glob" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -31,6 +32,7 @@ const ( OnPath = "path" OnRemote = "remote" OnBasicAuth = "basic_auth" + OnRoute = "route" ) var checkers = map[string]struct { @@ -229,6 +231,21 @@ var checkers = map[string]struct { } }, }, + OnRoute: { + help: Help{ + command: OnRoute, + args: map[string]string{ + "route": "the route name", + }, + }, + validate: validateSingleArg, + builder: func(args any) CheckFunc { + route := args.(string) + return func(_ Cache, r *http.Request) bool { + return r.Header.Get(httpheaders.HeaderUpstreamName) == route + } + }, + }, } var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index 9b4d3ee..1a1ed1c 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" . "github.com/yusing/go-proxy/internal/utils/testing" "golang.org/x/crypto/bcrypt" ) @@ -168,6 +169,22 @@ func TestParseOn(t *testing.T) { input: "unknown", wantErr: ErrInvalidOnTarget, }, + // route + { + name: "route_valid", + input: "route example", + wantErr: nil, + }, + { + name: "route_missing_arg", + input: "route", + wantErr: ErrExpectOneArg, + }, + { + name: "route_extra_arg", + input: "route example1 example2", + wantErr: ErrExpectOneArg, + }, } for _, tt := range tests { @@ -285,6 +302,24 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "route_match", + checker: "route example", + input: &http.Request{ + Header: http.Header{ + httpheaders.HeaderUpstreamName: {"example"}, + }, + }, + want: true, + }, + { + name: "route_no_match", + checker: "route example", + input: &http.Request{ + Header: http.Header{}, + }, + want: false, + }, } tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request { diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index 7f45f10..266e01d 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -30,6 +30,13 @@ func (t *Tuple[T1, T2]) String() string { return fmt.Sprintf("%v:%v", t.First, t.Second) } +func validateSingleArg(args []string) (any, gperr.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + return args[0], nil +} + // toStrTuple returns *StrTuple. func toStrTuple(args []string) (any, gperr.Error) { if len(args) != 2 {