updated route rules implementation

This commit is contained in:
yusing 2025-01-09 04:27:02 +08:00
parent f906e04581
commit 74828943a6
5 changed files with 117 additions and 44 deletions

View file

@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error {
}
if len(r.Raw.Rules) > 0 {
r.handler = r.Raw.Rules.BuildHandler(r.rp)
r.handler = r.Raw.Rules.BuildHandler(r.handler)
}
if r.HealthMon != nil {

View file

@ -7,17 +7,20 @@ import (
"strings"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
Command struct {
raw string
CommandExecutor
raw string
exec *CommandExecutor
}
CommandExecutor struct {
directive string
http.HandlerFunc
proceed bool
}
@ -34,7 +37,7 @@ const (
var commands = map[string]struct {
validate ValidateFunc
build func(args any) CommandExecutor
build func(args any) *CommandExecutor
}{
CommandRewrite: {
validate: func(args []string) (any, E.Error) {
@ -43,17 +46,22 @@ var commands = map[string]struct {
}
return validateURLPaths(args)
},
build: func(args any) CommandExecutor {
build: func(args any) *CommandExecutor {
a := args.([]string)
orig, repl := a[0], a[1]
return CommandExecutor{
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
if len(r.URL.Path) > 0 && r.URL.Path[0] != '/' {
r.URL.Path = "/" + r.URL.Path
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {
path = "/" + path
}
r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1)
if !strings.HasPrefix(path, orig) {
return
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.String()
r.RequestURI = r.URL.RequestURI()
},
proceed: true,
}
@ -61,9 +69,9 @@ var commands = map[string]struct {
},
CommandServe: {
validate: validateFSPath,
build: func(args any) CommandExecutor {
build: func(args any) *CommandExecutor {
root := args.(string)
return CommandExecutor{
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
},
@ -73,9 +81,9 @@ var commands = map[string]struct {
},
CommandRedirect: {
validate: validateURL,
build: func(args any) CommandExecutor {
build: func(args any) *CommandExecutor {
target := args.(types.URL).String()
return CommandExecutor{
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
},
@ -98,10 +106,10 @@ var commands = map[string]struct {
}
return []any{code, text}, nil
},
build: func(args any) CommandExecutor {
build: func(args any) *CommandExecutor {
a := args.([]any)
code, text := a[0].(int), a[1].(string)
return CommandExecutor{
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, text, code)
},
@ -111,13 +119,13 @@ var commands = map[string]struct {
},
CommandProxy: {
validate: validateURL,
build: func(args any) CommandExecutor {
build: func(args any) *CommandExecutor {
target := args.(types.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
return CommandExecutor{
return &CommandExecutor{
HandlerFunc: rp.ServeHTTP,
proceed: false,
}
@ -125,34 +133,77 @@ var commands = map[string]struct {
},
}
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {
cmd.raw = v
directive, args, err := parse(v)
if err != nil {
return err
}
if directive == CommandBypass {
if len(args) != 0 {
return ErrInvalidArguments.Subject(directive)
}
lines := strutils.SplitLine(v)
if len(lines) == 0 {
return nil
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
executors := make([]*CommandExecutor, 0, len(lines))
for _, line := range lines {
if line == "" {
continue
}
directive, args, err := parse(line)
if err != nil {
return err
}
if directive == CommandBypass {
if len(args) != 0 {
return ErrInvalidArguments.Subject(directive)
}
return nil
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
validArgs, err := builder.validate(args)
if err != nil {
return err.Subject(directive)
}
exec := builder.build(validArgs)
exec.directive = directive
executors = append(executors, exec)
}
validArgs, err := builder.validate(args)
exec, err := buildCmd(executors)
if err != nil {
return err.Subject(directive)
return err
}
cmd.CommandExecutor = builder.build(validArgs)
cmd.exec = exec
return nil
}
func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) {
for i, exec := range executors {
if !exec.proceed && i != len(executors)-1 {
return nil, ErrInvalidCommandSequence.
Withf("%s cannot follow %s", exec, executors[i+1])
}
}
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
for _, exec := range executors {
logging.Debug().
Str("directive", exec.directive).
Msg("executing command")
exec.HandlerFunc(w, r)
}
},
proceed: executors[len(executors)-1].proceed,
}, nil
}
func (cmd *Command) isBypass() bool {
return cmd.HandlerFunc == nil
return cmd.exec == nil
}
func (cmd *Command) String() string {
@ -162,3 +213,7 @@ func (cmd *Command) String() string {
func (cmd *Command) MarshalJSON() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil
}
func (exec *CommandExecutor) String() string {
return exec.directive
}

View file

@ -3,11 +3,12 @@ package rules
import E "github.com/yusing/go-proxy/internal/error"
var (
ErrUnterminatedQuotes = E.New("unterminated quotes")
ErrUnsupportedEscapeChar = E.New("unsupported escape char")
ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrUnterminatedQuotes = E.New("unterminated quotes")
ErrUnsupportedEscapeChar = E.New("unsupported escape char")
ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrInvalidCommandSequence = E.New("invalid command sequence")
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")

View file

@ -73,6 +73,7 @@ var checkers = map[string]struct {
},
}
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error {
on.raw = v
@ -81,6 +82,9 @@ func (on *RuleOn) Parse(v string) error {
errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines {
if line == "" {
continue
}
parsed, err := parseOn(line)
if err != nil {
errs.Add(err.Subjectf("line %d", i+1))

View file

@ -3,7 +3,7 @@ package rules
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/logging"
)
type (
@ -56,7 +56,7 @@ type (
// if no rule matches, the default rule is executed
// if no rule matches and default rule is not set,
// the request is passed to the upstream.
func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc {
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
// move bypass rules to the front.
bypassRules := make(Rules, 0, len(rules))
otherRules := make(Rules, 0, len(rules))
@ -80,12 +80,15 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc
bypassRules = []Rule{}
}
if len(otherRules) == 0 {
otherRules = []Rule{defaultRule}
otherRules = []Rule{}
}
return func(w http.ResponseWriter, r *http.Request) {
for _, rule := range bypassRules {
if rule.On.check(r) {
logging.Debug().
Str("rule", rule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r)
return
}
@ -93,18 +96,28 @@ func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc
hasMatch := false
for _, rule := range otherRules {
if rule.On.check(r) {
logging.Debug().
Str("rule", rule.Name).
Msgf("matched proceed=%t", rule.Do.exec.proceed)
hasMatch = true
rule.Do.HandlerFunc(w, r)
if !rule.Do.proceed {
rule.Do.exec.HandlerFunc(w, r)
if !rule.Do.exec.proceed {
return
}
}
}
if hasMatch || defaultRule.Do.isBypass() {
logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: bypass")
up.ServeHTTP(w, r)
return
}
defaultRule.Do.HandlerFunc(w, r)
logging.Debug().
Str("rule", defaultRule.Name).
Msg("matched: default")
defaultRule.Do.exec.HandlerFunc(w, r)
}
}