updated route rules implementation

This commit is contained in:
yusing 2025-01-09 04:27:02 +08:00
parent f906e04581
commit 74828943a6
5 changed files with 117 additions and 44 deletions

View file

@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error {
} }
if len(r.Raw.Rules) > 0 { if len(r.Raw.Rules) > 0 {
r.handler = r.Raw.Rules.BuildHandler(r.rp) r.handler = r.Raw.Rules.BuildHandler(r.handler)
} }
if r.HealthMon != nil { if r.HealthMon != nil {

View file

@ -7,17 +7,20 @@ import (
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
Command struct { Command struct {
raw string raw string
CommandExecutor exec *CommandExecutor
} }
CommandExecutor struct { CommandExecutor struct {
directive string
http.HandlerFunc http.HandlerFunc
proceed bool proceed bool
} }
@ -34,7 +37,7 @@ const (
var commands = map[string]struct { var commands = map[string]struct {
validate ValidateFunc validate ValidateFunc
build func(args any) CommandExecutor build func(args any) *CommandExecutor
}{ }{
CommandRewrite: { CommandRewrite: {
validate: func(args []string) (any, E.Error) { validate: func(args []string) (any, E.Error) {
@ -43,17 +46,22 @@ var commands = map[string]struct {
} }
return validateURLPaths(args) return validateURLPaths(args)
}, },
build: func(args any) CommandExecutor { build: func(args any) *CommandExecutor {
a := args.([]string) a := args.([]string)
orig, repl := a[0], a[1] orig, repl := a[0], a[1]
return CommandExecutor{ return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
if len(r.URL.Path) > 0 && r.URL.Path[0] != '/' { path := r.URL.Path
r.URL.Path = "/" + r.URL.Path if len(path) > 0 && path[0] != '/' {
path = "/" + path
} }
r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1) if !strings.HasPrefix(path, orig) {
return
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath() r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.String() r.RequestURI = r.URL.RequestURI()
}, },
proceed: true, proceed: true,
} }
@ -61,9 +69,9 @@ var commands = map[string]struct {
}, },
CommandServe: { CommandServe: {
validate: validateFSPath, validate: validateFSPath,
build: func(args any) CommandExecutor { build: func(args any) *CommandExecutor {
root := args.(string) root := args.(string)
return CommandExecutor{ return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
}, },
@ -73,9 +81,9 @@ var commands = map[string]struct {
}, },
CommandRedirect: { CommandRedirect: {
validate: validateURL, validate: validateURL,
build: func(args any) CommandExecutor { build: func(args any) *CommandExecutor {
target := args.(types.URL).String() target := args.(types.URL).String()
return CommandExecutor{ return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target, http.StatusTemporaryRedirect) http.Redirect(w, r, target, http.StatusTemporaryRedirect)
}, },
@ -98,10 +106,10 @@ var commands = map[string]struct {
} }
return []any{code, text}, nil return []any{code, text}, nil
}, },
build: func(args any) CommandExecutor { build: func(args any) *CommandExecutor {
a := args.([]any) a := args.([]any)
code, text := a[0].(int), a[1].(string) code, text := a[0].(int), a[1].(string)
return CommandExecutor{ return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, text, code) http.Error(w, text, code)
}, },
@ -111,13 +119,13 @@ var commands = map[string]struct {
}, },
CommandProxy: { CommandProxy: {
validate: validateURL, validate: validateURL,
build: func(args any) CommandExecutor { build: func(args any) *CommandExecutor {
target := args.(types.URL) target := args.(types.URL)
if target.Scheme == "" { if target.Scheme == "" {
target.Scheme = "http" target.Scheme = "http"
} }
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
return CommandExecutor{ return &CommandExecutor{
HandlerFunc: rp.ServeHTTP, HandlerFunc: rp.ServeHTTP,
proceed: false, proceed: false,
} }
@ -125,9 +133,22 @@ var commands = map[string]struct {
}, },
} }
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error { func (cmd *Command) Parse(v string) error {
cmd.raw = v cmd.raw = v
directive, args, err := parse(v)
lines := strutils.SplitLine(v)
if len(lines) == 0 {
return nil
}
executors := make([]*CommandExecutor, 0, len(lines))
for _, line := range lines {
if line == "" {
continue
}
directive, args, err := parse(line)
if err != nil { if err != nil {
return err return err
} }
@ -147,12 +168,42 @@ func (cmd *Command) Parse(v string) error {
if err != nil { if err != nil {
return err.Subject(directive) return err.Subject(directive)
} }
cmd.CommandExecutor = builder.build(validArgs)
exec := builder.build(validArgs)
exec.directive = directive
executors = append(executors, exec)
}
exec, err := buildCmd(executors)
if err != nil {
return err
}
cmd.exec = exec
return nil return nil
} }
func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) {
for i, exec := range executors {
if !exec.proceed && i != len(executors)-1 {
return nil, ErrInvalidCommandSequence.
Withf("%s cannot follow %s", exec, executors[i+1])
}
}
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
for _, exec := range executors {
logging.Debug().
Str("directive", exec.directive).
Msg("executing command")
exec.HandlerFunc(w, r)
}
},
proceed: executors[len(executors)-1].proceed,
}, nil
}
func (cmd *Command) isBypass() bool { func (cmd *Command) isBypass() bool {
return cmd.HandlerFunc == nil return cmd.exec == nil
} }
func (cmd *Command) String() string { func (cmd *Command) String() string {
@ -162,3 +213,7 @@ func (cmd *Command) String() string {
func (cmd *Command) MarshalJSON() ([]byte, error) { func (cmd *Command) MarshalJSON() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil return []byte("\"" + cmd.String() + "\""), nil
} }
func (exec *CommandExecutor) String() string {
return exec.directive
}

View file

@ -8,6 +8,7 @@ var (
ErrUnknownDirective = E.New("unknown directive") ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments") ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target") ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrInvalidCommandSequence = E.New("invalid command sequence")
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")

View file

@ -73,6 +73,7 @@ var checkers = map[string]struct {
}, },
} }
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error { func (on *RuleOn) Parse(v string) error {
on.raw = v on.raw = v
@ -81,6 +82,9 @@ func (on *RuleOn) Parse(v string) error {
errs := E.NewBuilder("rule.on syntax errors") errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines { for i, line := range lines {
if line == "" {
continue
}
parsed, err := parseOn(line) parsed, err := parseOn(line)
if err != nil { if err != nil {
errs.Add(err.Subjectf("line %d", i+1)) errs.Add(err.Subjectf("line %d", i+1))

View file

@ -3,7 +3,7 @@ package rules
import ( import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/logging"
) )
type ( type (
@ -56,7 +56,7 @@ type (
// if no rule matches, the default rule is executed // if no rule matches, the default rule is executed
// if no rule matches and default rule is not set, // if no rule matches and default rule is not set,
// the request is passed to the upstream. // the request is passed to the upstream.
func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc { func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
// move bypass rules to the front. // move bypass rules to the front.
bypassRules := make(Rules, 0, len(rules)) bypassRules := make(Rules, 0, len(rules))
otherRules := make(Rules, 0, len(rules)) otherRules := make(Rules, 0, len(rules))
@ -80,12 +80,15 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc
bypassRules = []Rule{} bypassRules = []Rule{}
} }
if len(otherRules) == 0 { if len(otherRules) == 0 {
otherRules = []Rule{defaultRule} otherRules = []Rule{}
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
for _, rule := range bypassRules { for _, rule := range bypassRules {
if rule.On.check(r) { if rule.On.check(r) {
logging.Debug().
Str("rule", rule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return return
} }
@ -93,18 +96,28 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc
hasMatch := false hasMatch := false
for _, rule := range otherRules { for _, rule := range otherRules {
if rule.On.check(r) { if rule.On.check(r) {
logging.Debug().
Str("rule", rule.Name).
Msgf("matched proceed=%t", rule.Do.exec.proceed)
hasMatch = true hasMatch = true
rule.Do.HandlerFunc(w, r) rule.Do.exec.HandlerFunc(w, r)
if !rule.Do.proceed { if !rule.Do.exec.proceed {
return return
} }
} }
} }
if hasMatch || defaultRule.Do.isBypass() { if hasMatch || defaultRule.Do.isBypass() {
logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return return
} }
defaultRule.Do.HandlerFunc(w, r) logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: default")
defaultRule.Do.exec.HandlerFunc(w, r)
} }
} }