From 0ce7f29976b5a409e5e7816b8a4b7da2101b5c74 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 11 Jan 2025 12:22:42 +0800 Subject: [PATCH] fix proxy rules behavior and implemented a few more rules and commands, dependencies upgrade --- .trunk/trunk.yaml | 4 +- cmd/main.go | 7 +- go.mod | 4 +- go.sum | 3 +- internal/route/http.go | 2 +- internal/route/rules/cache.go | 117 ++++++++++++ internal/route/rules/check_on.go | 34 ++++ internal/route/rules/command.go | 51 +++++ internal/route/rules/crypto.go | 27 +++ internal/route/rules/do.go | 224 +++++++++++++--------- internal/route/rules/do_test.go | 140 ++++++++++++++ internal/route/rules/errors.go | 2 + internal/route/rules/fields.go | 142 ++++++++++++++ internal/route/rules/on.go | 180 +++++++++--------- internal/route/rules/on_test.go | 89 +++++++++ internal/route/rules/parser.go | 1 + internal/route/rules/parser_test.go | 10 + internal/route/rules/rules.go | 65 ++++--- internal/route/rules/rules_test.go | 277 ++++------------------------ internal/route/rules/validate.go | 55 +++++- 20 files changed, 991 insertions(+), 443 deletions(-) create mode 100644 internal/route/rules/cache.go create mode 100644 internal/route/rules/check_on.go create mode 100644 internal/route/rules/command.go create mode 100644 internal/route/rules/crypto.go create mode 100644 internal/route/rules/do_test.go create mode 100644 internal/route/rules/fields.go create mode 100644 internal/route/rules/on_test.go diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 66cd116..d90c70f 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -23,7 +23,7 @@ lint: enabled: - hadolint@2.12.1-beta - actionlint@1.7.6 - - checkov@3.2.350 + - checkov@3.2.352 - git-diff-check - gofmt@1.20.4 - golangci-lint@1.63.4 @@ -32,7 +32,7 @@ lint: - prettier@3.4.2 - shellcheck@0.10.0 - shfmt@3.6.0 - - trufflehog@3.88.1 + - trufflehog@3.88.2 actions: disabled: - trunk-announce diff --git a/cmd/main.go b/cmd/main.go index 152336c..dc27715 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -20,6 +20,8 @@ import ( "github.com/yusing/go-proxy/pkg" ) +var rawLogger = log.New(os.Stdout, "", 0) + func main() { args := common.GetArgs() @@ -31,12 +33,12 @@ func main() { if err := query.ReloadServer(); err != nil { E.LogFatal("server reload error", err) } - logging.Info().Msg("ok") + rawLogger.Println("ok") return case common.CommandListIcons: icons, err := internal.ListAvailableIcons() if err != nil { - log.Fatal(err) + rawLogger.Fatal(err) } printJSON(icons) return @@ -139,6 +141,5 @@ func printJSON(obj any) { if err != nil { logging.Fatal().Err(err).Send() } - rawLogger := log.New(os.Stdout, "", 0) rawLogger.Print(string(j)) // raw output for convenience using "jq" } diff --git a/go.mod b/go.mod index 0094652..54fdeba 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/prometheus/client_golang v1.20.5 github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/rs/zerolog v1.33.0 + golang.org/x/crypto v0.32.0 golang.org/x/net v0.34.0 golang.org/x/text v0.21.0 golang.org/x/time v0.9.0 @@ -43,7 +44,7 @@ require ( github.com/google/go-querystring v1.1.0 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.62 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect @@ -65,7 +66,6 @@ require ( go.opentelemetry.io/otel/metric v1.33.0 // indirect go.opentelemetry.io/otel/sdk v1.30.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect - golang.org/x/crypto v0.32.0 // indirect golang.org/x/mod v0.22.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.sum b/go.sum index 9ce4691..aa6efb6 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,9 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= diff --git a/internal/route/http.go b/internal/route/http.go index e5ce760..ef8c100 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error { } if len(r.Raw.Rules) > 0 { - r.handler = r.Raw.Rules.BuildHandler(r.handler) + r.handler = r.Raw.Rules.BuildHandler(r.TargetName(), r.handler) } if r.HealthMon != nil { diff --git a/internal/route/rules/cache.go b/internal/route/rules/cache.go new file mode 100644 index 0000000..532f508 --- /dev/null +++ b/internal/route/rules/cache.go @@ -0,0 +1,117 @@ +package rules + +import ( + "net" + "net/http" + "net/url" + "sync" +) + +// Cache is a map of cached values for a request. +// It prevents the same value from being parsed multiple times. +type ( + Cache map[string]any + UpdateFunc[T any] func(T) T +) + +const ( + CacheKeyQueries = "queries" + CacheKeyCookies = "cookies" + CacheKeyRemoteIP = "remote_ip" + CacheKeyBasicAuth = "basic_auth" +) + +var cacheKeys = []string{ + CacheKeyQueries, + CacheKeyCookies, + CacheKeyRemoteIP, + CacheKeyBasicAuth, +} + +var cachePool = &sync.Pool{ + New: func() any { + return make(Cache) + }, +} + +// NewCache returns a new Cached. +func NewCache() Cache { + return cachePool.Get().(Cache) +} + +// Release clear the contents of the Cached and returns it to the pool. +func (c Cache) Release() { + for _, k := range cacheKeys { + delete(c, k) + } + cachePool.Put(c) +} + +// GetQueries returns the queries. +// If r does not have queries, an empty map is returned. +func (c Cache) GetQueries(r *http.Request) url.Values { + v, ok := c[CacheKeyQueries] + if !ok { + v = r.URL.Query() + c[CacheKeyQueries] = v + } + return v.(url.Values) +} + +func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) { + queries := c.GetQueries(r) + update(queries) + r.URL.RawQuery = queries.Encode() +} + +// GetCookies returns the cookies. +// If r does not have cookies, an empty slice is returned. +func (c Cache) GetCookies(r *http.Request) []*http.Cookie { + v, ok := c[CacheKeyCookies] + if !ok { + v = r.Cookies() + c[CacheKeyCookies] = v + } + return v.([]*http.Cookie) +} + +func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) { + cookies := update(c.GetCookies(r)) + c[CacheKeyCookies] = cookies + r.Header.Del("Cookie") + for _, cookie := range cookies { + r.AddCookie(cookie) + } +} + +// GetRemoteIP returns the remote ip address. +// If r.RemoteAddr is not a valid ip address, nil is returned. +func (c Cache) GetRemoteIP(r *http.Request) net.IP { + v, ok := c[CacheKeyRemoteIP] + if !ok { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + v = net.ParseIP(host) + c[CacheKeyRemoteIP] = v + } + return v.(net.IP) +} + +// GetBasicAuth returns *Credentials the basic auth username and password. +// If r does not have basic auth, nil is returned. +func (c Cache) GetBasicAuth(r *http.Request) *Credentials { + v, ok := c[CacheKeyBasicAuth] + if !ok { + u, p, ok := r.BasicAuth() + if ok { + v = &Credentials{u, []byte(p)} + c[CacheKeyBasicAuth] = v + } else { + c[CacheKeyBasicAuth] = nil + return nil + } + } + return v.(*Credentials) +} diff --git a/internal/route/rules/check_on.go b/internal/route/rules/check_on.go new file mode 100644 index 0000000..389c9aa --- /dev/null +++ b/internal/route/rules/check_on.go @@ -0,0 +1,34 @@ +package rules + +import "net/http" + +type ( + CheckFunc func(cached Cache, r *http.Request) bool + Checker interface { + Check(cached Cache, r *http.Request) bool + } + CheckMatchSingle []Checker + CheckMatchAll []Checker +) + +func (checker CheckFunc) Check(cached Cache, r *http.Request) bool { + return checker(cached, r) +} + +func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool { + for _, check := range checkers { + if check.Check(cached, r) { + return true + } + } + return false +} + +func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool { + for _, check := range checkers { + if !check.Check(cached, r) { + return false + } + } + return true +} diff --git a/internal/route/rules/command.go b/internal/route/rules/command.go new file mode 100644 index 0000000..3136066 --- /dev/null +++ b/internal/route/rules/command.go @@ -0,0 +1,51 @@ +package rules + +import "net/http" + +type ( + CommandHandler interface { + // CommandHandler can read and modify the values + // then handle the request + // finally proceed to next command (or return) base on situation + Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) + } + // StaticCommand will run then proceed to next command or reverse proxy. + StaticCommand http.HandlerFunc + // ReturningCommand will run then return immediately. + ReturningCommand http.HandlerFunc + // DynamicCommand will return base on the request + // and can raed or modify the values. + DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) + // BypassCommand will skip all the following commands + // and directly return to reverse proxy. + BypassCommand struct{} + // Commands is a slice of CommandHandler. + Commands []CommandHandler +) + +func (c StaticCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + c(w, r) + return true +} + +func (c ReturningCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + c(w, r) + return false +} + +func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + return c(cached, w, r) +} + +func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + return true +} + +func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + for _, cmd := range c { + if !cmd.Handle(cached, w, r) { + return false + } + } + return true +} diff --git a/internal/route/rules/crypto.go b/internal/route/rules/crypto.go new file mode 100644 index 0000000..3b05a1f --- /dev/null +++ b/internal/route/rules/crypto.go @@ -0,0 +1,27 @@ +package rules + +import "golang.org/x/crypto/bcrypt" + +type ( + HashedCrendentials struct { + Username string + CheckMatch func(inputPwd []byte) bool + } + Credentials struct { + Username string + Password []byte + } +) + +func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials { + return &HashedCrendentials{username, func(inputPwd []byte) bool { + return bcrypt.CompareHashAndPassword(hashedPassword, inputPwd) == nil + }} +} + +func (hc *HashedCrendentials) Match(cred *Credentials) bool { + if cred == nil { + return false + } + return hc.Username == cred.Username && hc.CheckMatch(cred.Password) +} diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index fed6ae1..f5edfb3 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -16,28 +16,27 @@ import ( type ( Command struct { raw string - exec *CommandExecutor - } - CommandExecutor struct { - directive string - http.HandlerFunc - proceed bool + exec CommandHandler } ) const ( - CommandRewrite = "rewrite" - CommandServe = "serve" - CommandProxy = "proxy" - CommandRedirect = "redirect" - CommandError = "error" - CommandBypass = "bypass" + CommandRewrite = "rewrite" + CommandServe = "serve" + CommandProxy = "proxy" + CommandRedirect = "redirect" + CommandError = "error" + CommandRequireBasicAuth = "require_basic_auth" + CommandSet = "set" + CommandAdd = "add" + CommandRemove = "remove" + CommandBypass = "bypass" ) var commands = map[string]struct { help Help validate ValidateFunc - build func(args any) *CommandExecutor + build func(args any) CommandHandler }{ CommandRewrite: { help: Help{ @@ -53,25 +52,22 @@ var commands = map[string]struct { } return validateURLPaths(args) }, - build: func(args any) *CommandExecutor { + build: func(args any) CommandHandler { a := args.([]string) orig, repl := a[0], a[1] - return &CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - if len(path) > 0 && path[0] != '/' { - path = "/" + path - } - if !strings.HasPrefix(path, orig) { - return - } - path = repl + path[len(orig):] - r.URL.Path = path - r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.RequestURI() - }, - proceed: true, - } + return StaticCommand(func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + if len(path) > 0 && path[0] != '/' { + path = "/" + path + } + if !strings.HasPrefix(path, orig) { + return + } + path = repl + path[len(orig):] + r.URL.Path = path + r.URL.RawPath = r.URL.EscapedPath() + r.RequestURI = r.URL.RequestURI() + }) }, }, CommandServe: { @@ -82,14 +78,11 @@ var commands = map[string]struct { }, }, validate: validateFSPath, - build: func(args any) *CommandExecutor { + build: func(args any) CommandHandler { root := args.(string) - return &CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) - }, - proceed: false, - } + return ReturningCommand(func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) + }) }, }, CommandRedirect: { @@ -100,14 +93,11 @@ var commands = map[string]struct { }, }, validate: validateURL, - build: func(args any) *CommandExecutor { + build: func(args any) CommandHandler { target := args.(types.URL).String() - return &CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, target, http.StatusTemporaryRedirect) - }, - proceed: false, - } + return ReturningCommand(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target, http.StatusTemporaryRedirect) + }) }, }, CommandError: { @@ -130,17 +120,34 @@ var commands = map[string]struct { if !gphttp.IsStatusCodeValid(code) { return nil, ErrInvalidArguments.Subject(codeStr) } - return []any{code, text}, nil + return &Tuple[int, string]{code, text}, nil }, - build: func(args any) *CommandExecutor { - a := args.([]any) - code, text := a[0].(int), a[1].(string) - return &CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.Error(w, text, code) - }, - proceed: false, + build: func(args any) CommandHandler { + code, text := args.(*Tuple[int, string]).Unpack() + return ReturningCommand(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, text, code) + }) + }, + }, + CommandRequireBasicAuth: { + help: Help{ + command: CommandRequireBasicAuth, + args: map[string]string{ + "realm": "the authentication realm", + }, + }, + validate: func(args []string) (any, E.Error) { + if len(args) == 1 { + return args[0], nil } + return nil, ErrExpectOneArg + }, + build: func(args any) CommandHandler { + realm := args.(string) + return ReturningCommand(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + }) }, }, CommandProxy: { @@ -151,30 +158,69 @@ var commands = map[string]struct { }, }, validate: validateAbsoluteURL, - build: func(args any) *CommandExecutor { + build: func(args any) CommandHandler { target := args.(types.URL) if target.Scheme == "" { target.Scheme = "http" } rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) - return &CommandExecutor{ - HandlerFunc: rp.ServeHTTP, - proceed: false, - } + return ReturningCommand(rp.ServeHTTP) + }, + }, + CommandSet: { + help: Help{ + command: CommandSet, + args: map[string]string{ + "field": "the field to set", + "value": "the value to set", + }, + }, + validate: func(args []string) (any, E.Error) { + return validateModField(ModFieldSet, args) + }, + build: func(args any) CommandHandler { + return args.(CommandHandler) + }, + }, + CommandAdd: { + help: Help{ + command: CommandAdd, + args: map[string]string{ + "field": "the field to add", + "value": "the value to add", + }, + }, + validate: func(args []string) (any, E.Error) { + return validateModField(ModFieldAdd, args) + }, + build: func(args any) CommandHandler { + return args.(CommandHandler) + }, + }, + CommandRemove: { + help: Help{ + command: CommandRemove, + args: map[string]string{ + "field": "the field to remove", + }, + }, + validate: func(args []string) (any, E.Error) { + return validateModField(ModFieldRemove, args) + }, + build: func(args any) CommandHandler { + return args.(CommandHandler) }, }, } // Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { - cmd.raw = v - lines := strutils.SplitLine(v) if len(lines) == 0 { return nil } - executors := make([]*CommandExecutor, 0, len(lines)) + executors := make([]CommandHandler, 0, len(lines)) for _, line := range lines { if line == "" { continue @@ -189,7 +235,7 @@ func (cmd *Command) Parse(v string) error { if len(args) != 0 { return ErrInvalidArguments.Subject(directive) } - executors = append(executors, nil) + executors = append(executors, BypassCommand{}) continue } @@ -202,48 +248,58 @@ func (cmd *Command) Parse(v string) error { return err.Subject(directive).Withf("%s", builder.help.String()) } - exec := builder.build(validArgs) - exec.directive = directive - executors = append(executors, exec) + executors = append(executors, builder.build(validArgs)) + } + + if len(executors) == 0 { + return nil } exec, err := buildCmd(executors) if err != nil { return err } + + cmd.raw = v cmd.exec = exec return nil } -func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) { +func buildCmd(executors []CommandHandler) (CommandHandler, error) { for i, exec := range executors { - if !exec.proceed && i != len(executors)-1 { - return nil, ErrInvalidCommandSequence. - Withf("%s cannot follow %s", exec, executors[i+1]) + switch exec.(type) { + case ReturningCommand, BypassCommand: + if i != len(executors)-1 { + return nil, ErrInvalidCommandSequence. + Withf("a returning / bypass command must be the last command") + } } } - return &CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - for _, exec := range executors { - exec.HandlerFunc(w, r) - } - }, - proceed: executors[len(executors)-1].proceed, - }, nil + + return Commands(executors), nil } +// Command is purely "bypass" or empty. func (cmd *Command) isBypass() bool { - return cmd.exec == nil + if cmd == nil { + return true + } + switch cmd := cmd.exec.(type) { + case BypassCommand: + return true + case Commands: + // bypass command is always the last one + _, ok := cmd[len(cmd)-1].(BypassCommand) + return ok + default: + return false + } } func (cmd *Command) String() string { return cmd.raw } -func (cmd *Command) MarshalJSON() ([]byte, error) { - return []byte("\"" + cmd.String() + "\""), nil -} - -func (exec *CommandExecutor) String() string { - return exec.directive +func (cmd *Command) MarshalText() ([]byte, error) { + return []byte(cmd.String()), nil } diff --git a/internal/route/rules/do_test.go b/internal/route/rules/do_test.go new file mode 100644 index 0000000..5c7b944 --- /dev/null +++ b/internal/route/rules/do_test.go @@ -0,0 +1,140 @@ +package rules + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestParseCommands(t *testing.T) { + tests := []struct { + name string + input string + wantErr error + }{ + // bypass tests + { + name: "bypass_valid", + input: "bypass", + wantErr: nil, + }, + { + name: "bypass_invalid_with_args", + input: "bypass /", + wantErr: ErrInvalidArguments, + }, + // rewrite tests + { + name: "rewrite_valid", + input: "rewrite / /foo/bar", + wantErr: nil, + }, + { + name: "rewrite_missing_target", + input: "rewrite /", + wantErr: ErrInvalidArguments, + }, + { + name: "rewrite_too_many_args", + input: "rewrite / / /", + wantErr: ErrInvalidArguments, + }, + { + name: "rewrite_no_leading_slash", + input: "rewrite abc /", + wantErr: ErrInvalidArguments, + }, + // serve tests + { + name: "serve_valid", + input: "serve /var/www", + wantErr: nil, + }, + { + name: "serve_missing_path", + input: "serve ", + wantErr: ErrInvalidArguments, + }, + { + name: "serve_too_many_args", + input: "serve / / /", + wantErr: ErrInvalidArguments, + }, + // redirect tests + { + name: "redirect_valid", + input: "redirect /", + wantErr: nil, + }, + { + name: "redirect_too_many_args", + input: "redirect / /", + wantErr: ErrInvalidArguments, + }, + // error directive tests + { + name: "error_valid", + input: "error 404 Not\\ Found", + wantErr: nil, + }, + { + name: "error_missing_status_code", + input: "error Not\\ Found", + wantErr: ErrInvalidArguments, + }, + { + name: "error_too_many_args", + input: "error 404 Not\\ Found extra", + wantErr: ErrInvalidArguments, + }, + { + name: "error_no_escaped_space", + input: "error 404 Not Found", + wantErr: ErrInvalidArguments, + }, + { + name: "error_invalid_status_code", + input: "error 123 abc", + wantErr: ErrInvalidArguments, + }, + // proxy directive tests + { + name: "proxy_valid", + input: "proxy http://localhost:8080", + wantErr: nil, + }, + { + name: "proxy_missing_target", + input: "proxy", + wantErr: ErrInvalidArguments, + }, + { + name: "proxy_too_many_args", + input: "proxy http://localhost:8080 extra", + wantErr: ErrInvalidArguments, + }, + { + name: "proxy_invalid_url", + input: "proxy invalid_url", + wantErr: ErrInvalidArguments, + }, + // unknown directive test + { + name: "unknown_directive", + input: "unknown /", + wantErr: ErrUnknownDirective, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := Command{} + err := cmd.Parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + } else { + ExpectNoError(t, err) + } + }) + } +} diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index ad79fb5..6a9e852 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -9,7 +9,9 @@ var ( ErrInvalidArguments = E.New("invalid arguments") ErrInvalidOnTarget = E.New("invalid `rule.on` target") ErrInvalidCommandSequence = E.New("invalid command sequence") + ErrInvalidSetTarget = E.New("invalid `rule.set` target") + ErrExpectNoArg = ErrInvalidArguments.Withf("expect no arg") ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") ) diff --git a/internal/route/rules/fields.go b/internal/route/rules/fields.go new file mode 100644 index 0000000..91fdf72 --- /dev/null +++ b/internal/route/rules/fields.go @@ -0,0 +1,142 @@ +package rules + +import ( + "net/http" + "net/url" +) + +type ( + FieldHandler struct { + set, add, remove CommandHandler + } + FieldModifier string +) + +const ( + ModFieldSet FieldModifier = "set" + ModFieldAdd FieldModifier = "add" + ModFieldRemove FieldModifier = "remove" +) + +const ( + FieldHeader = "header" + FieldQuery = "query" + FieldCookie = "cookie" +) + +var modFields = map[string]struct { + help Help + validate ValidateFunc + builder func(args any) *FieldHandler +}{ + FieldHeader: { + help: Help{ + command: FieldHeader, + args: map[string]string{ + "key": "the header key", + "value": "the header value", + }, + }, + validate: toStrTuple, + builder: func(args any) *FieldHandler { + k, v := args.(*StrTuple).Unpack() + return &FieldHandler{ + set: StaticCommand(func(w http.ResponseWriter, r *http.Request) { + w.Header()[k] = []string{v} + }), + add: StaticCommand(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h[k] = append(h[k], v) + }), + remove: StaticCommand(func(w http.ResponseWriter, r *http.Request) { + delete(w.Header(), k) + }), + } + }, + }, + FieldQuery: { + help: Help{ + command: FieldQuery, + args: map[string]string{ + "key": "the query key", + "value": "the query value", + }, + }, + validate: toStrTuple, + builder: func(args any) *FieldHandler { + k, v := args.(*StrTuple).Unpack() + return &FieldHandler{ + set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateQueries(r, func(queries url.Values) { + queries.Set(k, v) + }) + return true + }), + add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateQueries(r, func(queries url.Values) { + queries.Add(k, v) + }) + return true + }), + remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateQueries(r, func(queries url.Values) { + queries.Del(k) + }) + return true + }), + } + }, + }, + FieldCookie: { + help: Help{ + command: FieldCookie, + args: map[string]string{ + "key": "the cookie key", + "value": "the cookie value", + }, + }, + validate: toStrTuple, + builder: func(args any) *FieldHandler { + k, v := args.(*StrTuple).Unpack() + return &FieldHandler{ + set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + for i, c := range cookies { + if c.Name == k { + cookies[i].Value = v + return cookies + } + } + return append(cookies, &http.Cookie{Name: k, Value: v}) + }) + return true + }), + add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + return append(cookies, &http.Cookie{Name: k, Value: v}) + }) + return true + }), + remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { + cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + index := -1 + for i, c := range cookies { + if c.Name == k { + index = i + break + } + } + if index != -1 { + if len(cookies) == 1 { + return []*http.Cookie{} + } + return append(cookies[:index], cookies[index+1:]...) + } + return cookies + }) + return true + }), + } + }, + }, +} diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 5f9fccd..8a80ac8 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -1,37 +1,34 @@ package rules import ( - "net" "net/http" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) -type ( - RuleOn struct { - raw string - check CheckFulfill - } - CheckFulfill func(r *http.Request) bool - Checkers []CheckFulfill -) +type RuleOn struct { + raw string + checker Checker +} const ( - OnHeader = "header" - OnQuery = "query" - OnCookie = "cookie" - OnForm = "form" - OnPostForm = "postform" - OnMethod = "method" - OnPath = "path" - OnRemote = "remote" + OnHeader = "header" + OnQuery = "query" + OnCookie = "cookie" + OnForm = "form" + OnPostForm = "postform" + OnMethod = "method" + OnPath = "path" + OnRemote = "remote" + OnBasicAuth = "basic_auth" ) var checkers = map[string]struct { help Help validate ValidateFunc - check func(r *http.Request, args any) bool + builder func(args any) CheckFunc }{ OnHeader: { help: Help{ @@ -42,8 +39,11 @@ var checkers = map[string]struct { }, }, validate: toStrTuple, - check: func(r *http.Request, args any) bool { - return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second + builder: func(args any) CheckFunc { + k, v := args.(*StrTuple).Unpack() + return func(cached Cache, r *http.Request) bool { + return r.Header.Get(k) == v + } }, }, OnQuery: { @@ -55,8 +55,17 @@ var checkers = map[string]struct { }, }, validate: toStrTuple, - check: func(r *http.Request, args any) bool { - return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second + builder: func(args any) CheckFunc { + k, v := args.(*StrTuple).Unpack() + return func(cached Cache, r *http.Request) bool { + queries := cached.GetQueries(r)[k] + for _, query := range queries { + if query == v { + return true + } + } + return false + } }, }, OnCookie: { @@ -68,14 +77,18 @@ var checkers = map[string]struct { }, }, validate: toStrTuple, - check: func(r *http.Request, args any) bool { - cookies := r.CookiesNamed(args.(StrTuple).First) - for _, cookie := range cookies { - if cookie.Value == args.(StrTuple).Second { - return true + builder: func(args any) CheckFunc { + k, v := args.(*StrTuple).Unpack() + return func(cached Cache, r *http.Request) bool { + cookies := cached.GetCookies(r) + for _, cookie := range cookies { + if cookie.Name == k && + cookie.Value == v { + return true + } } + return false } - return false }, }, OnForm: { @@ -87,8 +100,11 @@ var checkers = map[string]struct { }, }, validate: toStrTuple, - check: func(r *http.Request, args any) bool { - return r.FormValue(args.(StrTuple).First) == args.(StrTuple).Second + builder: func(args any) CheckFunc { + k, v := args.(*StrTuple).Unpack() + return func(cached Cache, r *http.Request) bool { + return r.FormValue(k) == v + } }, }, OnPostForm: { @@ -100,8 +116,11 @@ var checkers = map[string]struct { }, }, validate: toStrTuple, - check: func(r *http.Request, args any) bool { - return r.PostFormValue(args.(StrTuple).First) == args.(StrTuple).Second + builder: func(args any) CheckFunc { + k, v := args.(*StrTuple).Unpack() + return func(cached Cache, r *http.Request) bool { + return r.PostFormValue(k) == v + } }, }, OnMethod: { @@ -112,8 +131,11 @@ var checkers = map[string]struct { }, }, validate: validateMethod, - check: func(r *http.Request, method any) bool { - return r.Method == method.(string) + builder: func(args any) CheckFunc { + method := args.(string) + return func(cached Cache, r *http.Request) bool { + return r.Method == method + } }, }, OnPath: { @@ -127,12 +149,15 @@ var checkers = map[string]struct { }, }, validate: validateURLPath, - check: func(r *http.Request, globPath any) bool { - reqPath := r.URL.Path - if len(reqPath) > 0 && reqPath[0] != '/' { - reqPath = "/" + reqPath + builder: func(args any) CheckFunc { + pat := args.(string) + return func(cached Cache, r *http.Request) bool { + reqPath := r.URL.Path + if len(reqPath) > 0 && reqPath[0] != '/' { + reqPath = "/" + reqPath + } + return strutils.GlobMatch(pat, reqPath) } - return strutils.GlobMatch(globPath.(string), reqPath) }, }, OnRemote: { @@ -143,16 +168,31 @@ var checkers = map[string]struct { }, }, validate: validateCIDR, - check: func(r *http.Request, cidr any) bool { - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - host = r.RemoteAddr + builder: func(args any) CheckFunc { + cidr := args.(types.CIDR) + return func(cached Cache, r *http.Request) bool { + ip := cached.GetRemoteIP(r) + if ip == nil { + return false + } + return cidr.Contains(ip) } - ip := net.ParseIP(host) - if ip == nil { - return false + }, + }, + OnBasicAuth: { + help: Help{ + command: OnBasicAuth, + args: map[string]string{ + "username": "the username", + "password": "the password encrypted with bcrypt", + }, + }, + validate: validateUserBCryptPassword, + builder: func(args any) CheckFunc { + cred := args.(*HashedCrendentials) + return func(cached Cache, r *http.Request) bool { + return cred.Match(cached.GetBasicAuth(r)) } - return cidr.(*net.IPNet).Contains(ip) }, }, } @@ -162,7 +202,7 @@ func (on *RuleOn) Parse(v string) error { on.raw = v lines := strutils.SplitLine(v) - checks := make(Checkers, 0, len(lines)) + checkAnd := make(CheckMatchAll, 0, len(lines)) errs := E.NewBuilder("rule.on syntax errors") for i, line := range lines { @@ -174,10 +214,10 @@ func (on *RuleOn) Parse(v string) error { errs.Add(err.Subjectf("line %d", i+1)) continue } - checks = append(checks, parsed.matchOne()) + checkAnd = append(checkAnd, parsed) } - on.check = checks.matchAll() + on.checker = checkAnd return errs.Error() } @@ -185,28 +225,28 @@ func (on *RuleOn) String() string { return on.raw } -func (on *RuleOn) MarshalJSON() ([]byte, error) { - return []byte("\"" + on.String() + "\""), nil +func (on *RuleOn) MarshalText() ([]byte, error) { + return []byte(on.String()), nil } -func parseOn(line string) (Checkers, E.Error) { +func parseOn(line string) (Checker, E.Error) { ors := strutils.SplitRune(line, '|') if len(ors) > 1 { errs := E.NewBuilder("rule.on syntax errors") - checks := make([]CheckFulfill, len(ors)) + checkOr := make(CheckMatchSingle, len(ors)) for i, or := range ors { curCheckers, err := parseOn(or) if err != nil { errs.Add(err) continue } - checks[i] = curCheckers[0] + checkOr[i] = curCheckers.(CheckFunc) } if err := errs.Error(); err != nil { return nil, err } - return checks, nil + return checkOr, nil } subject, args, err := parse(line) @@ -224,31 +264,5 @@ func parseOn(line string) (Checkers, E.Error) { return nil, err.Subject(subject).Withf("%s", checker.help.String()) } - return Checkers{ - func(r *http.Request) bool { - return checker.check(r, validArgs) - }, - }, nil -} - -func (checkers Checkers) matchOne() CheckFulfill { - return func(r *http.Request) bool { - for _, checker := range checkers { - if checker(r) { - return true - } - } - return false - } -} - -func (checkers Checkers) matchAll() CheckFulfill { - return func(r *http.Request) bool { - for _, checker := range checkers { - if !checker(r) { - return false - } - } - return true - } + return checker.builder(validArgs), nil } diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go new file mode 100644 index 0000000..acbe384 --- /dev/null +++ b/internal/route/rules/on_test.go @@ -0,0 +1,89 @@ +package rules + +import ( + "testing" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestParseOn(t *testing.T) { + tests := []struct { + name string + input string + wantErr E.Error + }{ + // header + { + name: "header_valid", + input: "header Connection Upgrade", + wantErr: nil, + }, + { + name: "header_invalid", + input: "header Connection", + wantErr: ErrInvalidArguments, + }, + // query + { + name: "query_valid", + input: "query key value", + wantErr: nil, + }, + { + name: "query_invalid", + input: "query key", + wantErr: ErrInvalidArguments, + }, + // method + { + name: "method_valid", + input: "method GET", + wantErr: nil, + }, + { + name: "method_invalid", + input: "method", + wantErr: ErrInvalidArguments, + }, + // path + { + name: "path_valid", + input: "path /home", + wantErr: nil, + }, + { + name: "path_invalid", + input: "path", + wantErr: ErrInvalidArguments, + }, + // remote + { + name: "remote_valid", + input: "remote 127.0.0.1", + wantErr: nil, + }, + { + name: "remote_invalid", + input: "remote", + wantErr: ErrInvalidArguments, + }, + { + name: "unknown_target", + input: "unknown", + wantErr: ErrInvalidOnTarget, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + on := &RuleOn{} + err := on.Parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + } else { + ExpectNoError(t, err) + } + }) + } +} diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go index 653c150..7ebc689 100644 --- a/internal/route/rules/parser.go +++ b/internal/route/rules/parser.go @@ -14,6 +14,7 @@ var escapedChars = map[rune]rune{ '\'': '\'', '"': '"', '\\': '\\', + '$': '$', ' ': ' ', } diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go index b560ef4..f43f284 100644 --- a/internal/route/rules/parser_test.go +++ b/internal/route/rules/parser_test.go @@ -94,3 +94,13 @@ func TestParser(t *testing.T) { } }) } + +func BenchmarkParser(b *testing.B) { + const input = `error 403 "Forbidden "foo" "bar""\ baz` + for range b.N { + _, _, err := parse(input) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index 1908eb1..3b5bf28 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -1,6 +1,7 @@ package rules import ( + "encoding/json" "net/http" ) @@ -26,7 +27,7 @@ type ( on: method POST | method PUT do: error 403 Forbidden */ - Rules []Rule + Rules []*Rule /* Rule is a rule for a reverse proxy. It do `Do` when `On` matches. @@ -52,52 +53,72 @@ type ( // if no rule matches, the default rule is executed // if no rule matches and default rule is not set, // the request is passed to the upstream. -func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { - var ( - defaultRule Rule - defaultRuleIndex int - ) +func (rules Rules) BuildHandler(caller string, up http.Handler) http.HandlerFunc { + var defaultRule *Rule + nonDefaultRules := make(Rules, 0, len(rules)) for i, rule := range rules { if rule.Name == "default" { defaultRule = rule - defaultRuleIndex = i + nonDefaultRules = append(nonDefaultRules, rules[:i]...) + nonDefaultRules = append(nonDefaultRules, rules[i+1:]...) break } } - rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...) - - // free allocated empty slices - // before encapsulating them into the handlerFunc. if len(rules) == 0 { if defaultRule.Do.isBypass() { return up.ServeHTTP } - rules = []Rule{} + return func(w http.ResponseWriter, r *http.Request) { + cache := NewCache() + defer cache.Release() + if defaultRule.Do.exec.Handle(cache, w, r) { + up.ServeHTTP(w, r) + } + } } return func(w http.ResponseWriter, r *http.Request) { - hasMatch := false - for _, rule := range rules { - if rule.On.check(r) { + cache := NewCache() + defer cache.Release() + + for _, rule := range nonDefaultRules { + if rule.Check(cache, r) { if rule.Do.isBypass() { up.ServeHTTP(w, r) return } - rule.Do.exec.HandlerFunc(w, r) - if !rule.Do.exec.proceed { + if !rule.Handle(cache, w, r) { return } - hasMatch = true } } - if hasMatch || defaultRule.Do.isBypass() { + // bypass or proceed + if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) { up.ServeHTTP(w, r) - return } - - defaultRule.Do.exec.HandlerFunc(w, r) } } + +func (rules Rules) MarshalJSON() ([]byte, error) { + names := make([]string, len(rules)) + for i, rule := range rules { + names[i] = rule.Name + } + return json.Marshal(names) +} + +func (rule *Rule) String() string { + return rule.Name +} + +func (rule *Rule) Check(cached Cache, r *http.Request) bool { + return rule.On.checker.Check(cached, r) +} + +func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { + proceed = rule.Do.exec.Handle(cached, w, r) + return +} diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go index dd2b662..e5bd7af 100644 --- a/internal/route/rules/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -3,249 +3,44 @@ package rules import ( "testing" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestParseSubjectArgs(t *testing.T) { - t.Run("basic", func(t *testing.T) { - subject, args, err := parse("rewrite / /foo/bar") - ExpectNoError(t, err) - ExpectEqual(t, subject, "rewrite") - ExpectDeepEqual(t, args, []string{"/", "/foo/bar"}) - }) - t.Run("with quotes", func(t *testing.T) { - subject, args, err := parse(`error 403 "Forbidden 'foo' 'bar'."`) - ExpectNoError(t, err) - ExpectEqual(t, subject, "error") - ExpectDeepEqual(t, args, []string{"403", "Forbidden 'foo' 'bar'."}) - }) - t.Run("with escaped", func(t *testing.T) { - subject, args, err := parse(`error 403 Forbidden\ \"foo\"\ \"bar\".`) - ExpectNoError(t, err) - ExpectEqual(t, subject, "error") - ExpectDeepEqual(t, args, []string{"403", "Forbidden \"foo\" \"bar\"."}) - }) -} - -func TestParseCommands(t *testing.T) { - tests := []struct { - name string - input string - wantErr error - }{ - // bypass tests - { - name: "bypass_valid", - input: "bypass", - wantErr: nil, - }, - { - name: "bypass_invalid_with_args", - input: "bypass /", - wantErr: ErrInvalidArguments, - }, - // rewrite tests - { - name: "rewrite_valid", - input: "rewrite / /foo/bar", - wantErr: nil, - }, - { - name: "rewrite_missing_target", - input: "rewrite /", - wantErr: ErrInvalidArguments, - }, - { - name: "rewrite_too_many_args", - input: "rewrite / / /", - wantErr: ErrInvalidArguments, - }, - { - name: "rewrite_no_leading_slash", - input: "rewrite abc /", - wantErr: ErrInvalidArguments, - }, - // serve tests - { - name: "serve_valid", - input: "serve /var/www", - wantErr: nil, - }, - { - name: "serve_missing_path", - input: "serve ", - wantErr: ErrInvalidArguments, - }, - { - name: "serve_too_many_args", - input: "serve / / /", - wantErr: ErrInvalidArguments, - }, - // redirect tests - { - name: "redirect_valid", - input: "redirect /", - wantErr: nil, - }, - { - name: "redirect_too_many_args", - input: "redirect / /", - wantErr: ErrInvalidArguments, - }, - // error directive tests - { - name: "error_valid", - input: "error 404 Not\\ Found", - wantErr: nil, - }, - { - name: "error_missing_status_code", - input: "error Not\\ Found", - wantErr: ErrInvalidArguments, - }, - { - name: "error_too_many_args", - input: "error 404 Not\\ Found extra", - wantErr: ErrInvalidArguments, - }, - { - name: "error_no_escaped_space", - input: "error 404 Not Found", - wantErr: ErrInvalidArguments, - }, - { - name: "error_invalid_status_code", - input: "error 123 abc", - wantErr: ErrInvalidArguments, - }, - // proxy directive tests - { - name: "proxy_valid", - input: "proxy localhost:8080", - wantErr: nil, - }, - { - name: "proxy_missing_target", - input: "proxy", - wantErr: ErrInvalidArguments, - }, - { - name: "proxy_too_many_args", - input: "proxy localhost:8080 extra", - wantErr: ErrInvalidArguments, - }, - { - name: "proxy_invalid_url", - input: "proxy :invalid_url", - wantErr: ErrInvalidArguments, - }, - // unknown directive test - { - name: "unknown_directive", - input: "unknown /", - wantErr: ErrUnknownDirective, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := Command{} - err := cmd.Parse(tt.input) - if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) - } else { - ExpectNoError(t, err) - } - }) - } -} - -func TestParseOn(t *testing.T) { - tests := []struct { - name string - input string - wantErr E.Error - }{ - // header - { - name: "header_valid", - input: "header Connection Upgrade", - wantErr: nil, - }, - { - name: "header_invalid", - input: "header Connection", - wantErr: ErrInvalidArguments, - }, - // query - { - name: "query_valid", - input: "query key value", - wantErr: nil, - }, - { - name: "query_invalid", - input: "query key", - wantErr: ErrInvalidArguments, - }, - // method - { - name: "method_valid", - input: "method GET", - wantErr: nil, - }, - { - name: "method_invalid", - input: "method", - wantErr: ErrInvalidArguments, - }, - // path - { - name: "path_valid", - input: "path /home", - wantErr: nil, - }, - { - name: "path_invalid", - input: "path", - wantErr: ErrInvalidArguments, - }, - // remote - { - name: "remote_valid", - input: "remote 127.0.0.1", - wantErr: nil, - }, - { - name: "remote_invalid", - input: "remote", - wantErr: ErrInvalidArguments, - }, - { - name: "unknown_target", - input: "unknown", - wantErr: ErrInvalidOnTarget, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - on := &RuleOn{} - err := on.Parse(tt.input) - if tt.wantErr != nil { - ExpectError(t, tt.wantErr, err) - } else { - ExpectNoError(t, err) - } - }) - } -} - func TestParseRule(t *testing.T) { - // test := map[string]any{ - // "name": "test", - // "on": "method GET", - // "do": "bypass", - // } + test := []map[string]any{ + { + "name": "test", + "on": "method POST", + "do": "error 403 Forbidden", + }, + { + "name": "auth", + "on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`, + "do": "bypass", + }, + { + "name": "default", + "do": "require_basic_auth any_realm", + }, + } + + var rules struct { + Rules Rules + } + err := utils.Deserialize(utils.SerializedObject{"rules": test}, &rules) + ExpectNoError(t, err) + ExpectEqual(t, len(rules.Rules), len(test)) + ExpectEqual(t, rules.Rules[0].Name, "test") + ExpectEqual(t, rules.Rules[0].On.String(), "method POST") + ExpectEqual(t, rules.Rules[0].Do.String(), "error 403 Forbidden") + + ExpectEqual(t, rules.Rules[1].Name, "auth") + ExpectEqual(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`) + ExpectEqual(t, rules.Rules[1].Do.String(), "bypass") + + ExpectEqual(t, rules.Rules[2].Name, "default") + ExpectEqual(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm") } + +// TODO: real tests. diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index 55621fb..da728f9 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -1,6 +1,7 @@ package rules import ( + "fmt" "os" "path" "strings" @@ -11,19 +12,31 @@ import ( ) type ( - ValidateFunc func(args []string) (any, E.Error) - StrTuple struct { - First, Second string + ValidateFunc func(args []string) (any, E.Error) + Tuple[T1, T2 any] struct { + First T1 + Second T2 } + StrTuple = Tuple[string, string] ) +func (t *Tuple[T1, T2]) Unpack() (T1, T2) { + return t.First, t.Second +} + +func (t *Tuple[T1, T2]) String() string { + return fmt.Sprintf("%v:%v", t.First, t.Second) +} + +// toStrTuple returns *StrTuple. func toStrTuple(args []string) (any, E.Error) { if len(args) != 2 { return nil, ErrExpectTwoArgs } - return StrTuple{args[0], args[1]}, nil + return &StrTuple{args[0], args[1]}, nil } +// validateURL returns types.URL with the URL validated. func validateURL(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -35,6 +48,7 @@ func validateURL(args []string) (any, E.Error) { return u, nil } +// validateAbsoluteURL returns types.URL with the URL validated. func validateAbsoluteURL(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -52,6 +66,7 @@ func validateAbsoluteURL(args []string) (any, E.Error) { return u, nil } +// validateCIDR returns types.CIDR with the CIDR validated. func validateCIDR(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -66,6 +81,7 @@ func validateCIDR(args []string) (any, E.Error) { return cidr, nil } +// validateURLPath returns string with the path validated. func validateURLPath(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -86,6 +102,7 @@ func validateURLPath(args []string) (any, E.Error) { return p, nil } +// validateURLPaths returns []string with each element validated. func validateURLPaths(paths []string) (any, E.Error) { errs := E.NewBuilder("invalid url paths") for i, p := range paths { @@ -102,6 +119,7 @@ func validateURLPaths(paths []string) (any, E.Error) { return paths, nil } +// validateFSPath returns string with the path validated. func validateFSPath(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -113,6 +131,7 @@ func validateFSPath(args []string) (any, E.Error) { return p, nil } +// validateMethod returns string with the method validated. func validateMethod(args []string) (any, E.Error) { if len(args) != 1 { return nil, ErrExpectOneArg @@ -123,3 +142,31 @@ func validateMethod(args []string) (any, E.Error) { } return method, nil } + +// validateUserBCryptPassword returns *HashedCrendential with the password validated. +func validateUserBCryptPassword(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + return BCryptCrendentials(args[0], []byte(args[1])), nil +} + +// validateModField returns CommandHandler with the field validated. +func validateModField(mod FieldModifier, args []string) (CommandHandler, E.Error) { + setField, ok := modFields[args[0]] + if !ok { + return nil, ErrInvalidSetTarget.Subject(args[0]) + } + validArgs, err := setField.validate(args[1:]) + if err != nil { + return nil, err.Withf(setField.help.String()) + } + modder := setField.builder(validArgs) + switch mod { + case ModFieldAdd: + return modder.add, nil + case ModFieldRemove: + return modder.remove, nil + } + return modder.set, nil +}