From a98b2bb71acac6f1283fb0f0a739c18da7e545ad Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 8 Jan 2025 13:50:10 +0800 Subject: [PATCH] updated implementation of rules --- internal/net/http/methods.go | 20 + internal/net/http/status_code.go | 4 + internal/net/types/cidr.go | 5 + internal/route/rules/do.go | 164 ++++++++ internal/route/rules/errors.go | 14 + internal/route/rules/on.go | 166 ++++++++ internal/route/rules/parser.go | 78 ++++ internal/route/rules/rules.go | 110 ++++++ internal/route/{types => rules}/rules_test.go | 111 +++++- internal/route/rules/validate.go | 104 +++++ internal/route/types/raw_entry.go | 3 +- internal/route/types/rules.go | 367 ------------------ 12 files changed, 775 insertions(+), 371 deletions(-) create mode 100644 internal/net/http/methods.go create mode 100644 internal/route/rules/do.go create mode 100644 internal/route/rules/errors.go create mode 100644 internal/route/rules/on.go create mode 100644 internal/route/rules/parser.go create mode 100644 internal/route/rules/rules.go rename internal/route/{types => rules}/rules_test.go (60%) create mode 100644 internal/route/rules/validate.go delete mode 100644 internal/route/types/rules.go diff --git a/internal/net/http/methods.go b/internal/net/http/methods.go new file mode 100644 index 0000000..a46923d --- /dev/null +++ b/internal/net/http/methods.go @@ -0,0 +1,20 @@ +package http + +import "net/http" + +var validMethods = map[string]struct{}{ + http.MethodGet: {}, + http.MethodHead: {}, + http.MethodPost: {}, + http.MethodPut: {}, + http.MethodPatch: {}, + http.MethodDelete: {}, + http.MethodConnect: {}, + http.MethodOptions: {}, + http.MethodTrace: {}, +} + +func IsMethodValid(method string) bool { + _, ok := validMethods[method] + return ok +} diff --git a/internal/net/http/status_code.go b/internal/net/http/status_code.go index db8002c..8235805 100644 --- a/internal/net/http/status_code.go +++ b/internal/net/http/status_code.go @@ -5,3 +5,7 @@ import "net/http" func IsSuccess(status int) bool { return status >= http.StatusOK && status < http.StatusMultipleChoices } + +func IsStatusCodeValid(status int) bool { + return http.StatusText(status) != "" +} diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 1aa00b9..67ca297 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -8,6 +8,11 @@ import ( //nolint:recvcheck type CIDR net.IPNet +func ParseCIDR(v string) (cidr CIDR, err error) { + err = cidr.Parse(v) + return +} + func (cidr *CIDR) Parse(v string) error { if !strings.Contains(v, "/") { v += "/32" // single IP diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go new file mode 100644 index 0000000..4319ff4 --- /dev/null +++ b/internal/route/rules/do.go @@ -0,0 +1,164 @@ +package rules + +import ( + "net/http" + "path" + "strconv" + "strings" + + E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/net/types" +) + +type ( + Command struct { + raw string + CommandExecutor + } + CommandExecutor struct { + http.HandlerFunc + proceed bool + } +) + +const ( + CommandRewrite = "rewrite" + CommandServe = "serve" + CommandProxy = "proxy" + CommandRedirect = "redirect" + CommandError = "error" + CommandBypass = "bypass" +) + +var commands = map[string]struct { + validate ValidateFunc + build func(args any) CommandExecutor +}{ + CommandRewrite: { + validate: func(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + return validateURLPaths(args) + }, + build: func(args any) CommandExecutor { + a := args.([]string) + orig, repl := a[0], a[1] + return CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + if len(r.URL.Path) > 0 && r.URL.Path[0] != '/' { + r.URL.Path = "/" + r.URL.Path + } + r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1) + r.URL.RawPath = r.URL.EscapedPath() + r.RequestURI = r.URL.String() + }, + proceed: true, + } + }, + }, + CommandServe: { + validate: validateFSPath, + build: func(args any) CommandExecutor { + root := args.(string) + return CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) + }, + proceed: false, + } + }, + }, + CommandRedirect: { + validate: validateURL, + build: func(args any) CommandExecutor { + target := args.(types.URL).String() + return CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target, http.StatusTemporaryRedirect) + }, + proceed: false, + } + }, + }, + CommandError: { + validate: func(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + codeStr, text := args[0], args[1] + code, err := strconv.Atoi(codeStr) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if !gphttp.IsStatusCodeValid(code) { + return nil, ErrInvalidArguments.Subject(codeStr) + } + return []any{code, text}, nil + }, + build: func(args any) CommandExecutor { + a := args.([]any) + code, text := a[0].(int), a[1].(string) + return CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, text, code) + }, + proceed: false, + } + }, + }, + CommandProxy: { + validate: validateURL, + build: func(args any) CommandExecutor { + target := args.(types.URL) + if target.Scheme == "" { + target.Scheme = "http" + } + rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) + return CommandExecutor{ + HandlerFunc: rp.ServeHTTP, + proceed: false, + } + }, + }, +} + +func (cmd *Command) Parse(v string) error { + cmd.raw = v + directive, args, err := parse(v) + if err != nil { + return err + } + + if directive == CommandBypass { + if len(args) != 0 { + return ErrInvalidArguments.Subject(directive) + } + return nil + } + + builder, ok := commands[directive] + if !ok { + return ErrUnknownDirective.Subject(directive) + } + validArgs, err := builder.validate(args) + if err != nil { + return err.Subject(directive) + } + cmd.CommandExecutor = builder.build(validArgs) + return nil +} + +func (cmd *Command) isBypass() bool { + return cmd.HandlerFunc == nil +} + +func (cmd *Command) String() string { + return cmd.raw +} + +func (cmd *Command) MarshalJSON() ([]byte, error) { + return []byte("\"" + cmd.String() + "\""), nil +} diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go new file mode 100644 index 0000000..d8484df --- /dev/null +++ b/internal/route/rules/errors.go @@ -0,0 +1,14 @@ +package rules + +import E "github.com/yusing/go-proxy/internal/error" + +var ( + ErrUnterminatedQuotes = E.New("unterminated quotes") + ErrUnsupportedEscapeChar = E.New("unsupported escape char") + ErrUnknownDirective = E.New("unknown directive") + ErrInvalidArguments = E.New("invalid arguments") + ErrInvalidOnTarget = E.New("invalid `rule.on` target") + + ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") + ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") +) diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go new file mode 100644 index 0000000..8b8063e --- /dev/null +++ b/internal/route/rules/on.go @@ -0,0 +1,166 @@ +package rules + +import ( + "net" + "net/http" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type ( + RuleOn struct { + raw string + check CheckFulfill + } + CheckFulfill func(r *http.Request) bool + Checkers []CheckFulfill +) + +const ( + OnHeader = "header" + OnQuery = "query" + OnMethod = "method" + OnPath = "path" + OnRemote = "remote" +) + +var checkers = map[string]struct { + validate ValidateFunc + check func(r *http.Request, args any) bool +}{ + OnHeader: { // header + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnQuery: { // query + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnMethod: { // method + validate: validateMethod, + check: func(r *http.Request, method any) bool { + return r.Method == method.(string) + }, + }, + OnPath: { // path + validate: validateURLPath, + check: func(r *http.Request, globPath any) bool { + reqPath := r.URL.Path + if len(reqPath) > 0 && reqPath[0] != '/' { + reqPath = "/" + reqPath + } + return strutils.GlobMatch(globPath.(string), reqPath) + }, + }, + OnRemote: { // remote + validate: validateCIDR, + check: func(r *http.Request, cidr any) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return cidr.(*net.IPNet).Contains(ip) + }, + }, +} + +func (on *RuleOn) Parse(v string) error { + on.raw = v + + lines := strutils.SplitLine(v) + checks := make(Checkers, 0, len(lines)) + + errs := E.NewBuilder("rule.on syntax errors") + for i, line := range lines { + parsed, err := parseOn(line) + if err != nil { + errs.Add(err.Subjectf("line %d", i+1)) + continue + } + checks = append(checks, parsed.matchOne()) + } + + on.check = checks.matchAll() + return errs.Error() +} + +func (on *RuleOn) String() string { + return on.raw +} + +func (on *RuleOn) MarshalJSON() ([]byte, error) { + return []byte("\"" + on.String() + "\""), nil +} + +func parseOn(line string) (Checkers, E.Error) { + ors := strutils.SplitRune(line, '|') + + if len(ors) > 1 { + errs := E.NewBuilder("rule.on syntax errors") + checks := make([]CheckFulfill, len(ors)) + for i, or := range ors { + curCheckers, err := parseOn(or) + if err != nil { + errs.Add(err) + continue + } + checks[i] = curCheckers[0] + } + if err := errs.Error(); err != nil { + return nil, err + } + return checks, nil + } + + subject, args, err := parse(line) + if err != nil { + return nil, err + } + + checker, ok := checkers[subject] + if !ok { + return nil, ErrInvalidOnTarget.Subject(subject) + } + + validArgs, err := checker.validate(args) + if err != nil { + return nil, err.Subject(subject) + } + + return Checkers{ + func(r *http.Request) bool { + return checker.check(r, validArgs) + }, + }, nil +} + +func (checkers Checkers) matchOne() CheckFulfill { + return func(r *http.Request) bool { + for _, checker := range checkers { + if checker(r) { + return true + } + } + return false + } +} + +func (checkers Checkers) matchAll() CheckFulfill { + return func(r *http.Request) bool { + for _, checker := range checkers { + if !checker(r) { + return false + } + } + return true + } +} diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go new file mode 100644 index 0000000..2840ac8 --- /dev/null +++ b/internal/route/rules/parser.go @@ -0,0 +1,78 @@ +package rules + +import ( + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +var escapedChars = map[rune]rune{ + 'n': '\n', + 't': '\t', + 'r': '\r', + '\'': '\'', + '"': '"', + ' ': ' ', +} + +// parse expression to subject and args +// with support for quotes and escaped chars, e.g. +// +// error 403 "Forbidden 'foo' 'bar'" +// error 403 Forbidden\ \"foo\"\ \"bar\". +func parse(v string) (subject string, args []string, err E.Error) { + v = strings.TrimSpace(v) + var buf strings.Builder + escaped := false + quotes := make([]rune, 0, 4) + flush := func() { + if subject == "" { + subject = buf.String() + } else { + args = append(args, buf.String()) + } + buf.Reset() + } + for _, r := range v { + if escaped { + if ch, ok := escapedChars[r]; ok { + buf.WriteRune(ch) + } else { + err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r) + return + } + escaped = false + continue + } + switch r { + case '\\': + escaped = true + continue + case '"', '\'': + switch { + case len(quotes) > 0 && quotes[len(quotes)-1] == r: + quotes = quotes[:len(quotes)-1] + if len(quotes) == 0 { + flush() + } else { + buf.WriteRune(r) + } + case len(quotes) == 0: + quotes = append(quotes, r) + default: + buf.WriteRune(r) + } + case ' ': + flush() + default: + buf.WriteRune(r) + } + } + + if len(quotes) > 0 { + err = ErrUnterminatedQuotes + } else { + flush() + } + return +} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go new file mode 100644 index 0000000..54c267b --- /dev/null +++ b/internal/route/rules/rules.go @@ -0,0 +1,110 @@ +package rules + +import ( + "net/http" + + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" +) + +type ( + /* + Example: + + proxy.app1.rules: | + - name: default + do: | + rewrite / /index.html + serve /var/www/goaccess + - name: ws + on: | + header Connection Upgrade + header Upgrade websocket + do: bypass + + proxy.app2.rules: | + - name: default + do: bypass + - name: block POST and PUT + on: method POST | method PUT + do: error 403 Forbidden + */ + Rules []Rule + /* + Rule is a rule for a reverse proxy. + It do `Do` when `On` matches. + + A rule can have multiple lines of on. + + All lines of on must match, + but each line can have multiple checks that + one match means this line is matched. + */ + Rule struct { + Name string `json:"name" validate:"required,unique"` + On RuleOn `json:"on"` + Do Command `json:"do"` + } +) + +// BuildHandler returns a http.HandlerFunc that implements the rules. +// +// Bypass rules are executed first +// if a bypass rule matches, +// the request is passed to the upstream and no more rules are executed. +// +// Other rules are executed later +// if no rule matches, the default rule is executed +// if no rule matches and default rule is not set, +// the request is passed to the upstream. +func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc { + // move bypass rules to the front. + bypassRules := make(Rules, 0, len(rules)) + otherRules := make(Rules, 0, len(rules)) + + var defaultRule Rule + + for _, rule := range rules { + switch { + case rule.Do.isBypass(): + bypassRules = append(bypassRules, rule) + case rule.Name == "default": + defaultRule = rule + default: + otherRules = append(otherRules, rule) + } + } + + // free allocated empty slices + // before encapsulating them into the handlerFunc. + if len(bypassRules) == 0 { + bypassRules = []Rule{} + } + if len(otherRules) == 0 { + otherRules = []Rule{defaultRule} + } + + return func(w http.ResponseWriter, r *http.Request) { + for _, rule := range bypassRules { + if rule.On.check(r) { + up.ServeHTTP(w, r) + return + } + } + hasMatch := false + for _, rule := range otherRules { + if rule.On.check(r) { + hasMatch = true + rule.Do.HandlerFunc(w, r) + if !rule.Do.proceed { + return + } + } + } + if hasMatch || defaultRule.Do.isBypass() { + up.ServeHTTP(w, r) + return + } + + defaultRule.Do.HandlerFunc(w, r) + } +} diff --git a/internal/route/types/rules_test.go b/internal/route/rules/rules_test.go similarity index 60% rename from internal/route/types/rules_test.go rename to internal/route/rules/rules_test.go index 9f8a414..dd2b662 100644 --- a/internal/route/types/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -1,13 +1,14 @@ -package types +package rules import ( "testing" + E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestParseSubjectArgs(t *testing.T) { - t.Run("without quotes", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { subject, args, err := parse("rewrite / /foo/bar") ExpectNoError(t, err) ExpectEqual(t, subject, "rewrite") @@ -60,6 +61,11 @@ func TestParseCommands(t *testing.T) { input: "rewrite / / /", wantErr: ErrInvalidArguments, }, + { + name: "rewrite_no_leading_slash", + input: "rewrite abc /", + wantErr: ErrInvalidArguments, + }, // serve tests { name: "serve_valid", @@ -104,10 +110,15 @@ func TestParseCommands(t *testing.T) { wantErr: ErrInvalidArguments, }, { - name: "error_unescaped_space", + name: "error_no_escaped_space", input: "error 404 Not Found", wantErr: ErrInvalidArguments, }, + { + name: "error_invalid_status_code", + input: "error 123 abc", + wantErr: ErrInvalidArguments, + }, // proxy directive tests { name: "proxy_valid", @@ -124,6 +135,11 @@ func TestParseCommands(t *testing.T) { input: "proxy localhost:8080 extra", wantErr: ErrInvalidArguments, }, + { + name: "proxy_invalid_url", + input: "proxy :invalid_url", + wantErr: ErrInvalidArguments, + }, // unknown directive test { name: "unknown_directive", @@ -144,3 +160,92 @@ func TestParseCommands(t *testing.T) { }) } } + +func TestParseOn(t *testing.T) { + tests := []struct { + name string + input string + wantErr E.Error + }{ + // header + { + name: "header_valid", + input: "header Connection Upgrade", + wantErr: nil, + }, + { + name: "header_invalid", + input: "header Connection", + wantErr: ErrInvalidArguments, + }, + // query + { + name: "query_valid", + input: "query key value", + wantErr: nil, + }, + { + name: "query_invalid", + input: "query key", + wantErr: ErrInvalidArguments, + }, + // method + { + name: "method_valid", + input: "method GET", + wantErr: nil, + }, + { + name: "method_invalid", + input: "method", + wantErr: ErrInvalidArguments, + }, + // path + { + name: "path_valid", + input: "path /home", + wantErr: nil, + }, + { + name: "path_invalid", + input: "path", + wantErr: ErrInvalidArguments, + }, + // remote + { + name: "remote_valid", + input: "remote 127.0.0.1", + wantErr: nil, + }, + { + name: "remote_invalid", + input: "remote", + wantErr: ErrInvalidArguments, + }, + { + name: "unknown_target", + input: "unknown", + wantErr: ErrInvalidOnTarget, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + on := &RuleOn{} + err := on.Parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + } else { + ExpectNoError(t, err) + } + }) + } +} + +func TestParseRule(t *testing.T) { + // test := map[string]any{ + // "name": "test", + // "on": "method GET", + // "do": "bypass", + // } +} diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go new file mode 100644 index 0000000..649d530 --- /dev/null +++ b/internal/route/rules/validate.go @@ -0,0 +1,104 @@ +package rules + +import ( + "os" + "path" + "strings" + + E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/types" +) + +type ( + ValidateFunc func(args []string) (any, E.Error) + StrTuple struct { + First, Second string + } +) + +func toStrTuple(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + return StrTuple{args[0], args[1]}, nil +} + +func validateURL(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + u, err := types.ParseURL(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return u, nil +} + +func validateCIDR(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + if !strings.Contains(args[0], "/") { + args[0] += "/32" + } + cidr, err := types.ParseCIDR(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return cidr, nil +} + +func validateURLPath(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + p := args[0] + p, _, _ = strings.Cut(p, "#") + p = path.Clean(p) + if len(p) == 0 { + return "/", nil + } + if p[0] != '/' { + return nil, ErrInvalidArguments.Withf("must start with /") + } + return p, nil +} + +func validateURLPaths(paths []string) (any, E.Error) { + errs := E.NewBuilder("invalid url paths") + for i, p := range paths { + val, err := validateURLPath([]string{p}) + if err != nil { + errs.Add(err.Subject(p)) + continue + } + paths[i] = val.(string) + } + if err := errs.Error(); err != nil { + return nil, err + } + return paths, nil +} + +func validateFSPath(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + p := path.Clean(args[0]) + if _, err := os.Stat(p); err != nil { + return nil, ErrInvalidArguments.With(err) + } + return p, nil +} + +func validateMethod(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + method := strings.ToUpper(args[0]) + if !gphttp.IsMethodValid(method) { + return nil, ErrInvalidArguments.Subject(method) + } + return method, nil +} diff --git a/internal/route/types/raw_entry.go b/internal/route/types/raw_entry.go index 92a77b0..2670c71 100644 --- a/internal/route/types/raw_entry.go +++ b/internal/route/types/raw_entry.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/accesslog" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/route/rules" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -30,7 +31,7 @@ type ( Port string `json:"port,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"` PathPatterns []string `json:"path_patterns,omitempty"` - Rules Rules `json:"rules,omitempty"` + Rules rules.Rules `json:"rules,omitempty"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` diff --git a/internal/route/types/rules.go b/internal/route/types/rules.go deleted file mode 100644 index be6f7f7..0000000 --- a/internal/route/types/rules.go +++ /dev/null @@ -1,367 +0,0 @@ -package types - -import ( - "net/http" - "path" - "strconv" - "strings" - - E "github.com/yusing/go-proxy/internal/error" - gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/utils/strutils" -) - -type ( - Rules []Rule - Rule struct { - Name string `json:"name" validate:"required,unique"` - On RuleOn `json:"on"` - Do Command `json:"do"` - } - RuleOn struct { - raw string - checkers []CheckFulfill - } - Command struct { - raw string - CommandExecutor - } - CheckFulfill func(r *http.Request) bool - RequestObjectRetriever struct { - expectedArgs int - retrieve func(r *http.Request, args []string) string - equal func(v, want string) bool - } - CommandExecutor struct { - http.HandlerFunc - proceed bool - } - CommandBuilder struct { - expectedArgs int - build func(args []string) CommandExecutor - } -) - -/* -proxy.app1.rules: | - - name: default - do: | - rewrite / /index.html - serve /var/www/goaccess - - name: ws - on: | - header Connection upgrade - header Upgrade websocket - do: proxy $upstream_url -*/ - -var ( - ErrUnterminatedQuotes = E.New("unterminated quotes") - ErrUnsupportedEscapeChar = E.New("unsupported escape char") - ErrUnknownDirective = E.New("unknown directive") - ErrInvalidArguments = E.New("invalid arguments") - ErrInvalidCriteria = E.New("invalid criteria") - ErrInvalidCriteriaTarget = E.New("invalid criteria target") -) - -var retrievers = map[string]RequestObjectRetriever{ - "header": {1, func(r *http.Request, args []string) string { - return r.Header.Get(args[0]) - }, nil}, - "query": {1, func(r *http.Request, args []string) string { - return r.URL.Query().Get(args[0]) - }, nil}, - "method": {0, func(r *http.Request, _ []string) string { - return r.Method - }, nil}, - "path": {0, func(r *http.Request, _ []string) string { - return r.URL.Path - }, func(v, want string) bool { - return strutils.GlobMatch(want, v) - }}, - "remote": {0, func(r *http.Request, _ []string) string { - return r.RemoteAddr - }, nil}, -} - -var commands = map[string]CommandBuilder{ - "rewrite": {2, func(args []string) CommandExecutor { - orig, repl := args[0], args[1] - return CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1) - r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.String() - }, - proceed: true, - } - }}, - "serve": {1, func(args []string) CommandExecutor { - return CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, path.Join(args[0], path.Clean(r.URL.Path))) - }, - proceed: false, - } - }}, - "redirect": {1, func(args []string) CommandExecutor { - target := args[0] - return CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, target, http.StatusTemporaryRedirect) - }, - proceed: false, - } - }}, - "error": {2, func(args []string) CommandExecutor { - codeStr, text := args[0], args[1] - code, err := strconv.Atoi(codeStr) - if err != nil { - code = http.StatusNotFound - } - return CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - http.Error(w, text, code) - }, - proceed: false, - } - }}, - "proxy": {1, func(args []string) CommandExecutor { - target := args[0] - return CommandExecutor{ - HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - r.URL.Scheme = "http" - r.URL.Host = target - r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.String() - }, - proceed: true, - } - }}, -} - -var escapedChars = map[rune]rune{ - 'n': '\n', - 't': '\t', - 'r': '\r', - '\'': '\'', - '"': '"', - ' ': ' ', -} - -// BuildHandler returns a http.HandlerFunc that implements the rules. -// -// Bypass rules are executed first -// if a bypass rule matches, -// the request is passed to the upstream and no more rules are executed. -// -// Other rules are executed later -// if no rule matches, the default rule is executed -// if no rule matches and default rule is not set, -// the request is passed to the upstream. -func (rules Rules) BuildHandler(up *gphttp.ReverseProxy) http.HandlerFunc { - // move bypass rules to the front. - bypassRules := make(Rules, 0, len(rules)) - otherRules := make(Rules, 0, len(rules)) - - var defaultRule Rule - - for _, rule := range rules { - switch { - case rule.Do.isBypass(): - bypassRules = append(bypassRules, rule) - case rule.Name == "default": - defaultRule = rule - default: - otherRules = append(otherRules, rule) - } - } - - // free allocated empty slices - // before passing them to the handler. - if len(bypassRules) == 0 { - bypassRules = []Rule{} - } - if len(otherRules) == 0 { - otherRules = []Rule{defaultRule} - } - - return func(w http.ResponseWriter, r *http.Request) { - hasMatch := false - for _, rule := range bypassRules { - if rule.On.MatchAll(r) { - up.ServeHTTP(w, r) - return - } - } - for _, rule := range otherRules { - if rule.On.MatchAll(r) { - hasMatch = true - rule.Do.HandlerFunc(w, r) - if !rule.Do.proceed { - return - } - } - } - if hasMatch || defaultRule.Do.isBypass() { - up.ServeHTTP(w, r) - return - } - - defaultRule.Do.HandlerFunc(w, r) - if !defaultRule.Do.proceed { - return - } - } -} - -// parse line to subject and args -// with support for quotes and escaped chars, e.g. -// -// error 403 "Forbidden 'foo' 'bar'" -// error 403 Forbidden\ \"foo\"\ \"bar\". -func parse(v string) (subject string, args []string, err E.Error) { - v = strings.TrimSpace(v) - var buf strings.Builder - escaped := false - quotes := make([]rune, 0, 4) - flush := func() { - if subject == "" { - subject = buf.String() - } else { - args = append(args, buf.String()) - } - buf.Reset() - } - for _, r := range v { - if escaped { - if ch, ok := escapedChars[r]; ok { - buf.WriteRune(ch) - } else { - err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r) - return - } - escaped = false - continue - } - switch r { - case '\\': - escaped = true - continue - case '"', '\'': - switch { - case len(quotes) > 0 && quotes[len(quotes)-1] == r: - quotes = quotes[:len(quotes)-1] - if len(quotes) == 0 { - flush() - } else { - buf.WriteRune(r) - } - case len(quotes) == 0: - quotes = append(quotes, r) - default: - buf.WriteRune(r) - } - case ' ': - flush() - default: - buf.WriteRune(r) - } - } - - if len(quotes) > 0 { - err = ErrUnterminatedQuotes - } else { - flush() - } - return -} - -func (on *RuleOn) Parse(v string) E.Error { - lines := strutils.SplitLine(v) - on.checkers = make([]CheckFulfill, 0, len(lines)) - on.raw = v - - errs := E.NewBuilder("rule.on syntax errors") - for i, line := range lines { - subject, args, err := parse(line) - if err != nil { - errs.Add(err.Subjectf("line %d", i+1)) - continue - } - retriever, ok := retrievers[subject] - if !ok { - errs.Add(ErrInvalidCriteriaTarget.Subject(subject).Subjectf("line %d", i+1)) - continue - } - nArgs := retriever.expectedArgs - if len(args) != nArgs+1 { - errs.Add(ErrInvalidArguments.Subject(subject).Subjectf("line %d", i+1)) - continue - } - equal := retriever.equal - if equal == nil { - equal = func(a, b string) bool { - return a == b - } - } - on.checkers = append(on.checkers, func(r *http.Request) bool { - return equal(retriever.retrieve(r, args[:nArgs]), args[nArgs]) - }) - } - return errs.Error() -} - -func (on *RuleOn) MatchAll(r *http.Request) bool { - for _, match := range on.checkers { - if !match(r) { - return false - } - } - return true -} - -func (cmd *Command) Parse(v string) E.Error { - cmd.raw = v - directive, args, err := parse(v) - if err != nil { - return err - } - - if directive == "bypass" { - if len(args) != 0 { - return ErrInvalidArguments.Subject(directive) - } - return nil - } - - builder, ok := commands[directive] - if !ok { - return ErrUnknownDirective.Subject(directive) - } - if len(args) != builder.expectedArgs { - return ErrInvalidArguments.Subject(directive) - } - cmd.CommandExecutor = builder.build(args) - return nil -} - -func (cmd *Command) isBypass() bool { - return cmd.HandlerFunc == nil -} - -func (on *RuleOn) String() string { - return on.raw -} - -func (on *RuleOn) MarshalJSON() ([]byte, error) { - return []byte("\"" + on.String() + "\""), nil -} - -func (cmd *Command) String() string { - return cmd.raw -} - -func (cmd *Command) MarshalJSON() ([]byte, error) { - return []byte("\"" + cmd.String() + "\""), nil -}