diff --git a/internal/route/http.go b/internal/route/http.go index 96fb3f0..e5ce760 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error { } if len(r.Raw.Rules) > 0 { - r.handler = r.Raw.Rules.BuildHandler(r.rp) + r.handler = r.Raw.Rules.BuildHandler(r.handler) } if r.HealthMon != nil { diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 4319ff4..45968be 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -7,17 +7,20 @@ import ( "strings" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( Command struct { - raw string - CommandExecutor + raw string + exec *CommandExecutor } CommandExecutor struct { + directive string http.HandlerFunc proceed bool } @@ -34,7 +37,7 @@ const ( var commands = map[string]struct { validate ValidateFunc - build func(args any) CommandExecutor + build func(args any) *CommandExecutor }{ CommandRewrite: { validate: func(args []string) (any, E.Error) { @@ -43,17 +46,22 @@ var commands = map[string]struct { } return validateURLPaths(args) }, - build: func(args any) CommandExecutor { + build: func(args any) *CommandExecutor { a := args.([]string) orig, repl := a[0], a[1] - return CommandExecutor{ + return &CommandExecutor{ HandlerFunc: func(w http.ResponseWriter, r *http.Request) { - if len(r.URL.Path) > 0 && r.URL.Path[0] != '/' { - r.URL.Path = "/" + r.URL.Path + path := r.URL.Path + if len(path) > 0 && path[0] != '/' { + path = "/" + path } - r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1) + if !strings.HasPrefix(path, orig) { + return + } + path = repl + path[len(orig):] + r.URL.Path = path r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.String() + r.RequestURI = r.URL.RequestURI() }, proceed: true, } @@ -61,9 +69,9 @@ var commands = map[string]struct { }, CommandServe: { validate: validateFSPath, - build: func(args any) CommandExecutor { + build: func(args any) *CommandExecutor { root := args.(string) - return CommandExecutor{ + return &CommandExecutor{ HandlerFunc: func(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) }, @@ -73,9 +81,9 @@ var commands = map[string]struct { }, CommandRedirect: { validate: validateURL, - build: func(args any) CommandExecutor { + build: func(args any) *CommandExecutor { target := args.(types.URL).String() - return CommandExecutor{ + return &CommandExecutor{ HandlerFunc: func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, target, http.StatusTemporaryRedirect) }, @@ -98,10 +106,10 @@ var commands = map[string]struct { } return []any{code, text}, nil }, - build: func(args any) CommandExecutor { + build: func(args any) *CommandExecutor { a := args.([]any) code, text := a[0].(int), a[1].(string) - return CommandExecutor{ + return &CommandExecutor{ HandlerFunc: func(w http.ResponseWriter, r *http.Request) { http.Error(w, text, code) }, @@ -111,13 +119,13 @@ var commands = map[string]struct { }, CommandProxy: { validate: validateURL, - build: func(args any) CommandExecutor { + build: func(args any) *CommandExecutor { target := args.(types.URL) if target.Scheme == "" { target.Scheme = "http" } rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) - return CommandExecutor{ + return &CommandExecutor{ HandlerFunc: rp.ServeHTTP, proceed: false, } @@ -125,34 +133,77 @@ var commands = map[string]struct { }, } +// Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { cmd.raw = v - directive, args, err := parse(v) - if err != nil { - return err - } - if directive == CommandBypass { - if len(args) != 0 { - return ErrInvalidArguments.Subject(directive) - } + lines := strutils.SplitLine(v) + if len(lines) == 0 { return nil } - builder, ok := commands[directive] - if !ok { - return ErrUnknownDirective.Subject(directive) + executors := make([]*CommandExecutor, 0, len(lines)) + for _, line := range lines { + if line == "" { + continue + } + + directive, args, err := parse(line) + if err != nil { + return err + } + + if directive == CommandBypass { + if len(args) != 0 { + return ErrInvalidArguments.Subject(directive) + } + return nil + } + + builder, ok := commands[directive] + if !ok { + return ErrUnknownDirective.Subject(directive) + } + validArgs, err := builder.validate(args) + if err != nil { + return err.Subject(directive) + } + + exec := builder.build(validArgs) + exec.directive = directive + executors = append(executors, exec) } - validArgs, err := builder.validate(args) + + exec, err := buildCmd(executors) if err != nil { - return err.Subject(directive) + return err } - cmd.CommandExecutor = builder.build(validArgs) + cmd.exec = exec return nil } +func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) { + for i, exec := range executors { + if !exec.proceed && i != len(executors)-1 { + return nil, ErrInvalidCommandSequence. + Withf("%s cannot follow %s", exec, executors[i+1]) + } + } + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + for _, exec := range executors { + logging.Debug(). + Str("directive", exec.directive). + Msg("executing command") + exec.HandlerFunc(w, r) + } + }, + proceed: executors[len(executors)-1].proceed, + }, nil +} + func (cmd *Command) isBypass() bool { - return cmd.HandlerFunc == nil + return cmd.exec == nil } func (cmd *Command) String() string { @@ -162,3 +213,7 @@ func (cmd *Command) String() string { func (cmd *Command) MarshalJSON() ([]byte, error) { return []byte("\"" + cmd.String() + "\""), nil } + +func (exec *CommandExecutor) String() string { + return exec.directive +} diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index d8484df..ad79fb5 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -3,11 +3,12 @@ package rules import E "github.com/yusing/go-proxy/internal/error" var ( - ErrUnterminatedQuotes = E.New("unterminated quotes") - ErrUnsupportedEscapeChar = E.New("unsupported escape char") - ErrUnknownDirective = E.New("unknown directive") - ErrInvalidArguments = E.New("invalid arguments") - ErrInvalidOnTarget = E.New("invalid `rule.on` target") + ErrUnterminatedQuotes = E.New("unterminated quotes") + ErrUnsupportedEscapeChar = E.New("unsupported escape char") + ErrUnknownDirective = E.New("unknown directive") + ErrInvalidArguments = E.New("invalid arguments") + ErrInvalidOnTarget = E.New("invalid `rule.on` target") + ErrInvalidCommandSequence = E.New("invalid command sequence") ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 8b8063e..7771366 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -73,6 +73,7 @@ var checkers = map[string]struct { }, } +// Parse implements strutils.Parser. func (on *RuleOn) Parse(v string) error { on.raw = v @@ -81,6 +82,9 @@ func (on *RuleOn) Parse(v string) error { errs := E.NewBuilder("rule.on syntax errors") for i, line := range lines { + if line == "" { + continue + } parsed, err := parseOn(line) if err != nil { errs.Add(err.Subjectf("line %d", i+1)) diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index 54c267b..857e530 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -3,7 +3,7 @@ package rules import ( "net/http" - "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/logging" ) type ( @@ -56,7 +56,7 @@ type ( // if no rule matches, the default rule is executed // if no rule matches and default rule is not set, // the request is passed to the upstream. -func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc { +func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { // move bypass rules to the front. bypassRules := make(Rules, 0, len(rules)) otherRules := make(Rules, 0, len(rules)) @@ -80,12 +80,15 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc bypassRules = []Rule{} } if len(otherRules) == 0 { - otherRules = []Rule{defaultRule} + otherRules = []Rule{} } return func(w http.ResponseWriter, r *http.Request) { for _, rule := range bypassRules { if rule.On.check(r) { + logging.Debug(). + Str("rule", rule.Name). + Msg("matched: bypass") up.ServeHTTP(w, r) return } @@ -93,18 +96,28 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc hasMatch := false for _, rule := range otherRules { if rule.On.check(r) { + logging.Debug(). + Str("rule", rule.Name). + Msgf("matched proceed=%t", rule.Do.exec.proceed) hasMatch = true - rule.Do.HandlerFunc(w, r) - if !rule.Do.proceed { + rule.Do.exec.HandlerFunc(w, r) + if !rule.Do.exec.proceed { return } } } if hasMatch || defaultRule.Do.isBypass() { + logging.Debug(). + Str("rule", defaultRule.Name). + Msg("matched: bypass") up.ServeHTTP(w, r) return } - defaultRule.Do.HandlerFunc(w, r) + logging.Debug(). + Str("rule", defaultRule.Name). + Msg("matched: default") + + defaultRule.Do.exec.HandlerFunc(w, r) } }