fix proxy rules behavior and implemented a few more rules and commands, dependencies upgrade

This commit is contained in:
yusing 2025-01-11 12:22:42 +08:00
parent f2df756c17
commit 0ce7f29976
20 changed files with 991 additions and 443 deletions

View file

@ -23,7 +23,7 @@ lint:
enabled: enabled:
- hadolint@2.12.1-beta - hadolint@2.12.1-beta
- actionlint@1.7.6 - actionlint@1.7.6
- checkov@3.2.350 - checkov@3.2.352
- git-diff-check - git-diff-check
- gofmt@1.20.4 - gofmt@1.20.4
- golangci-lint@1.63.4 - golangci-lint@1.63.4
@ -32,7 +32,7 @@ lint:
- prettier@3.4.2 - prettier@3.4.2
- shellcheck@0.10.0 - shellcheck@0.10.0
- shfmt@3.6.0 - shfmt@3.6.0
- trufflehog@3.88.1 - trufflehog@3.88.2
actions: actions:
disabled: disabled:
- trunk-announce - trunk-announce

View file

@ -20,6 +20,8 @@ import (
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
) )
var rawLogger = log.New(os.Stdout, "", 0)
func main() { func main() {
args := common.GetArgs() args := common.GetArgs()
@ -31,12 +33,12 @@ func main() {
if err := query.ReloadServer(); err != nil { if err := query.ReloadServer(); err != nil {
E.LogFatal("server reload error", err) E.LogFatal("server reload error", err)
} }
logging.Info().Msg("ok") rawLogger.Println("ok")
return return
case common.CommandListIcons: case common.CommandListIcons:
icons, err := internal.ListAvailableIcons() icons, err := internal.ListAvailableIcons()
if err != nil { if err != nil {
log.Fatal(err) rawLogger.Fatal(err)
} }
printJSON(icons) printJSON(icons)
return return
@ -139,6 +141,5 @@ func printJSON(obj any) {
if err != nil { if err != nil {
logging.Fatal().Err(err).Send() logging.Fatal().Err(err).Send()
} }
rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Print(string(j)) // raw output for convenience using "jq" rawLogger.Print(string(j)) // raw output for convenience using "jq"
} }

4
go.mod
View file

@ -15,6 +15,7 @@ require (
github.com/prometheus/client_golang v1.20.5 github.com/prometheus/client_golang v1.20.5
github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
golang.org/x/crypto v0.32.0
golang.org/x/net v0.34.0 golang.org/x/net v0.34.0
golang.org/x/text v0.21.0 golang.org/x/text v0.21.0
golang.org/x/time v0.9.0 golang.org/x/time v0.9.0
@ -43,7 +44,7 @@ require (
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/compress v1.17.11 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.62 // indirect github.com/miekg/dns v1.1.62 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
@ -65,7 +66,6 @@ require (
go.opentelemetry.io/otel/metric v1.33.0 // indirect go.opentelemetry.io/otel/metric v1.33.0 // indirect
go.opentelemetry.io/otel/sdk v1.30.0 // indirect go.opentelemetry.io/otel/sdk v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.33.0 // indirect go.opentelemetry.io/otel/trace v1.33.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/mod v0.22.0 // indirect golang.org/x/mod v0.22.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.10.0 // indirect

3
go.sum
View file

@ -86,8 +86,9 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=

View file

@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error {
} }
if len(r.Raw.Rules) > 0 { if len(r.Raw.Rules) > 0 {
r.handler = r.Raw.Rules.BuildHandler(r.handler) r.handler = r.Raw.Rules.BuildHandler(r.TargetName(), r.handler)
} }
if r.HealthMon != nil { if r.HealthMon != nil {

View file

@ -0,0 +1,117 @@
package rules
import (
"net"
"net/http"
"net/url"
"sync"
)
// Cache is a map of cached values for a request.
// It prevents the same value from being parsed multiple times.
type (
Cache map[string]any
UpdateFunc[T any] func(T) T
)
const (
CacheKeyQueries = "queries"
CacheKeyCookies = "cookies"
CacheKeyRemoteIP = "remote_ip"
CacheKeyBasicAuth = "basic_auth"
)
var cacheKeys = []string{
CacheKeyQueries,
CacheKeyCookies,
CacheKeyRemoteIP,
CacheKeyBasicAuth,
}
var cachePool = &sync.Pool{
New: func() any {
return make(Cache)
},
}
// NewCache returns a new Cached.
func NewCache() Cache {
return cachePool.Get().(Cache)
}
// Release clear the contents of the Cached and returns it to the pool.
func (c Cache) Release() {
for _, k := range cacheKeys {
delete(c, k)
}
cachePool.Put(c)
}
// GetQueries returns the queries.
// If r does not have queries, an empty map is returned.
func (c Cache) GetQueries(r *http.Request) url.Values {
v, ok := c[CacheKeyQueries]
if !ok {
v = r.URL.Query()
c[CacheKeyQueries] = v
}
return v.(url.Values)
}
func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) {
queries := c.GetQueries(r)
update(queries)
r.URL.RawQuery = queries.Encode()
}
// GetCookies returns the cookies.
// If r does not have cookies, an empty slice is returned.
func (c Cache) GetCookies(r *http.Request) []*http.Cookie {
v, ok := c[CacheKeyCookies]
if !ok {
v = r.Cookies()
c[CacheKeyCookies] = v
}
return v.([]*http.Cookie)
}
func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) {
cookies := update(c.GetCookies(r))
c[CacheKeyCookies] = cookies
r.Header.Del("Cookie")
for _, cookie := range cookies {
r.AddCookie(cookie)
}
}
// GetRemoteIP returns the remote ip address.
// If r.RemoteAddr is not a valid ip address, nil is returned.
func (c Cache) GetRemoteIP(r *http.Request) net.IP {
v, ok := c[CacheKeyRemoteIP]
if !ok {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
v = net.ParseIP(host)
c[CacheKeyRemoteIP] = v
}
return v.(net.IP)
}
// GetBasicAuth returns *Credentials the basic auth username and password.
// If r does not have basic auth, nil is returned.
func (c Cache) GetBasicAuth(r *http.Request) *Credentials {
v, ok := c[CacheKeyBasicAuth]
if !ok {
u, p, ok := r.BasicAuth()
if ok {
v = &Credentials{u, []byte(p)}
c[CacheKeyBasicAuth] = v
} else {
c[CacheKeyBasicAuth] = nil
return nil
}
}
return v.(*Credentials)
}

View file

@ -0,0 +1,34 @@
package rules
import "net/http"
type (
CheckFunc func(cached Cache, r *http.Request) bool
Checker interface {
Check(cached Cache, r *http.Request) bool
}
CheckMatchSingle []Checker
CheckMatchAll []Checker
)
func (checker CheckFunc) Check(cached Cache, r *http.Request) bool {
return checker(cached, r)
}
func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool {
for _, check := range checkers {
if check.Check(cached, r) {
return true
}
}
return false
}
func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool {
for _, check := range checkers {
if !check.Check(cached, r) {
return false
}
}
return true
}

View file

@ -0,0 +1,51 @@
package rules
import "net/http"
type (
CommandHandler interface {
// CommandHandler can read and modify the values
// then handle the request
// finally proceed to next command (or return) base on situation
Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
}
// StaticCommand will run then proceed to next command or reverse proxy.
StaticCommand http.HandlerFunc
// ReturningCommand will run then return immediately.
ReturningCommand http.HandlerFunc
// DynamicCommand will return base on the request
// and can raed or modify the values.
DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
// BypassCommand will skip all the following commands
// and directly return to reverse proxy.
BypassCommand struct{}
// Commands is a slice of CommandHandler.
Commands []CommandHandler
)
func (c StaticCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
c(w, r)
return true
}
func (c ReturningCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
c(w, r)
return false
}
func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
return c(cached, w, r)
}
func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
return true
}
func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
for _, cmd := range c {
if !cmd.Handle(cached, w, r) {
return false
}
}
return true
}

View file

@ -0,0 +1,27 @@
package rules
import "golang.org/x/crypto/bcrypt"
type (
HashedCrendentials struct {
Username string
CheckMatch func(inputPwd []byte) bool
}
Credentials struct {
Username string
Password []byte
}
)
func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials {
return &HashedCrendentials{username, func(inputPwd []byte) bool {
return bcrypt.CompareHashAndPassword(hashedPassword, inputPwd) == nil
}}
}
func (hc *HashedCrendentials) Match(cred *Credentials) bool {
if cred == nil {
return false
}
return hc.Username == cred.Username && hc.CheckMatch(cred.Password)
}

View file

@ -16,28 +16,27 @@ import (
type ( type (
Command struct { Command struct {
raw string raw string
exec *CommandExecutor exec CommandHandler
}
CommandExecutor struct {
directive string
http.HandlerFunc
proceed bool
} }
) )
const ( const (
CommandRewrite = "rewrite" CommandRewrite = "rewrite"
CommandServe = "serve" CommandServe = "serve"
CommandProxy = "proxy" CommandProxy = "proxy"
CommandRedirect = "redirect" CommandRedirect = "redirect"
CommandError = "error" CommandError = "error"
CommandBypass = "bypass" CommandRequireBasicAuth = "require_basic_auth"
CommandSet = "set"
CommandAdd = "add"
CommandRemove = "remove"
CommandBypass = "bypass"
) )
var commands = map[string]struct { var commands = map[string]struct {
help Help help Help
validate ValidateFunc validate ValidateFunc
build func(args any) *CommandExecutor build func(args any) CommandHandler
}{ }{
CommandRewrite: { CommandRewrite: {
help: Help{ help: Help{
@ -53,25 +52,22 @@ var commands = map[string]struct {
} }
return validateURLPaths(args) return validateURLPaths(args)
}, },
build: func(args any) *CommandExecutor { build: func(args any) CommandHandler {
a := args.([]string) a := args.([]string)
orig, repl := a[0], a[1] orig, repl := a[0], a[1]
return &CommandExecutor{ return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { path := r.URL.Path
path := r.URL.Path if len(path) > 0 && path[0] != '/' {
if len(path) > 0 && path[0] != '/' { path = "/" + path
path = "/" + path }
} if !strings.HasPrefix(path, orig) {
if !strings.HasPrefix(path, orig) { return
return }
} path = repl + path[len(orig):]
path = repl + path[len(orig):] r.URL.Path = path
r.URL.Path = path r.URL.RawPath = r.URL.EscapedPath()
r.URL.RawPath = r.URL.EscapedPath() r.RequestURI = r.URL.RequestURI()
r.RequestURI = r.URL.RequestURI() })
},
proceed: true,
}
}, },
}, },
CommandServe: { CommandServe: {
@ -82,14 +78,11 @@ var commands = map[string]struct {
}, },
}, },
validate: validateFSPath, validate: validateFSPath,
build: func(args any) *CommandExecutor { build: func(args any) CommandHandler {
root := args.(string) root := args.(string)
return &CommandExecutor{ return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) })
},
proceed: false,
}
}, },
}, },
CommandRedirect: { CommandRedirect: {
@ -100,14 +93,11 @@ var commands = map[string]struct {
}, },
}, },
validate: validateURL, validate: validateURL,
build: func(args any) *CommandExecutor { build: func(args any) CommandHandler {
target := args.(types.URL).String() target := args.(types.URL).String()
return &CommandExecutor{ return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, target, http.StatusTemporaryRedirect)
http.Redirect(w, r, target, http.StatusTemporaryRedirect) })
},
proceed: false,
}
}, },
}, },
CommandError: { CommandError: {
@ -130,17 +120,34 @@ var commands = map[string]struct {
if !gphttp.IsStatusCodeValid(code) { if !gphttp.IsStatusCodeValid(code) {
return nil, ErrInvalidArguments.Subject(codeStr) return nil, ErrInvalidArguments.Subject(codeStr)
} }
return []any{code, text}, nil return &Tuple[int, string]{code, text}, nil
}, },
build: func(args any) *CommandExecutor { build: func(args any) CommandHandler {
a := args.([]any) code, text := args.(*Tuple[int, string]).Unpack()
code, text := a[0].(int), a[1].(string) return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
return &CommandExecutor{ http.Error(w, text, code)
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { })
http.Error(w, text, code) },
}, },
proceed: false, CommandRequireBasicAuth: {
help: Help{
command: CommandRequireBasicAuth,
args: map[string]string{
"realm": "the authentication realm",
},
},
validate: func(args []string) (any, E.Error) {
if len(args) == 1 {
return args[0], nil
} }
return nil, ErrExpectOneArg
},
build: func(args any) CommandHandler {
realm := args.(string)
return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
}, },
}, },
CommandProxy: { CommandProxy: {
@ -151,30 +158,69 @@ var commands = map[string]struct {
}, },
}, },
validate: validateAbsoluteURL, validate: validateAbsoluteURL,
build: func(args any) *CommandExecutor { build: func(args any) CommandHandler {
target := args.(types.URL) target := args.(types.URL)
if target.Scheme == "" { if target.Scheme == "" {
target.Scheme = "http" target.Scheme = "http"
} }
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
return &CommandExecutor{ return ReturningCommand(rp.ServeHTTP)
HandlerFunc: rp.ServeHTTP, },
proceed: false, },
} CommandSet: {
help: Help{
command: CommandSet,
args: map[string]string{
"field": "the field to set",
"value": "the value to set",
},
},
validate: func(args []string) (any, E.Error) {
return validateModField(ModFieldSet, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
},
},
CommandAdd: {
help: Help{
command: CommandAdd,
args: map[string]string{
"field": "the field to add",
"value": "the value to add",
},
},
validate: func(args []string) (any, E.Error) {
return validateModField(ModFieldAdd, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
},
},
CommandRemove: {
help: Help{
command: CommandRemove,
args: map[string]string{
"field": "the field to remove",
},
},
validate: func(args []string) (any, E.Error) {
return validateModField(ModFieldRemove, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
}, },
}, },
} }
// Parse implements strutils.Parser. // Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error { func (cmd *Command) Parse(v string) error {
cmd.raw = v
lines := strutils.SplitLine(v) lines := strutils.SplitLine(v)
if len(lines) == 0 { if len(lines) == 0 {
return nil return nil
} }
executors := make([]*CommandExecutor, 0, len(lines)) executors := make([]CommandHandler, 0, len(lines))
for _, line := range lines { for _, line := range lines {
if line == "" { if line == "" {
continue continue
@ -189,7 +235,7 @@ func (cmd *Command) Parse(v string) error {
if len(args) != 0 { if len(args) != 0 {
return ErrInvalidArguments.Subject(directive) return ErrInvalidArguments.Subject(directive)
} }
executors = append(executors, nil) executors = append(executors, BypassCommand{})
continue continue
} }
@ -202,48 +248,58 @@ func (cmd *Command) Parse(v string) error {
return err.Subject(directive).Withf("%s", builder.help.String()) return err.Subject(directive).Withf("%s", builder.help.String())
} }
exec := builder.build(validArgs) executors = append(executors, builder.build(validArgs))
exec.directive = directive }
executors = append(executors, exec)
if len(executors) == 0 {
return nil
} }
exec, err := buildCmd(executors) exec, err := buildCmd(executors)
if err != nil { if err != nil {
return err return err
} }
cmd.raw = v
cmd.exec = exec cmd.exec = exec
return nil return nil
} }
func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) { func buildCmd(executors []CommandHandler) (CommandHandler, error) {
for i, exec := range executors { for i, exec := range executors {
if !exec.proceed && i != len(executors)-1 { switch exec.(type) {
return nil, ErrInvalidCommandSequence. case ReturningCommand, BypassCommand:
Withf("%s cannot follow %s", exec, executors[i+1]) if i != len(executors)-1 {
return nil, ErrInvalidCommandSequence.
Withf("a returning / bypass command must be the last command")
}
} }
} }
return &CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) { return Commands(executors), nil
for _, exec := range executors {
exec.HandlerFunc(w, r)
}
},
proceed: executors[len(executors)-1].proceed,
}, nil
} }
// Command is purely "bypass" or empty.
func (cmd *Command) isBypass() bool { func (cmd *Command) isBypass() bool {
return cmd.exec == nil if cmd == nil {
return true
}
switch cmd := cmd.exec.(type) {
case BypassCommand:
return true
case Commands:
// bypass command is always the last one
_, ok := cmd[len(cmd)-1].(BypassCommand)
return ok
default:
return false
}
} }
func (cmd *Command) String() string { func (cmd *Command) String() string {
return cmd.raw return cmd.raw
} }
func (cmd *Command) MarshalJSON() ([]byte, error) { func (cmd *Command) MarshalText() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil return []byte(cmd.String()), nil
}
func (exec *CommandExecutor) String() string {
return exec.directive
} }

View file

@ -0,0 +1,140 @@
package rules
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseCommands(t *testing.T) {
tests := []struct {
name string
input string
wantErr error
}{
// bypass tests
{
name: "bypass_valid",
input: "bypass",
wantErr: nil,
},
{
name: "bypass_invalid_with_args",
input: "bypass /",
wantErr: ErrInvalidArguments,
},
// rewrite tests
{
name: "rewrite_valid",
input: "rewrite / /foo/bar",
wantErr: nil,
},
{
name: "rewrite_missing_target",
input: "rewrite /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_too_many_args",
input: "rewrite / / /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_no_leading_slash",
input: "rewrite abc /",
wantErr: ErrInvalidArguments,
},
// serve tests
{
name: "serve_valid",
input: "serve /var/www",
wantErr: nil,
},
{
name: "serve_missing_path",
input: "serve ",
wantErr: ErrInvalidArguments,
},
{
name: "serve_too_many_args",
input: "serve / / /",
wantErr: ErrInvalidArguments,
},
// redirect tests
{
name: "redirect_valid",
input: "redirect /",
wantErr: nil,
},
{
name: "redirect_too_many_args",
input: "redirect / /",
wantErr: ErrInvalidArguments,
},
// error directive tests
{
name: "error_valid",
input: "error 404 Not\\ Found",
wantErr: nil,
},
{
name: "error_missing_status_code",
input: "error Not\\ Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_too_many_args",
input: "error 404 Not\\ Found extra",
wantErr: ErrInvalidArguments,
},
{
name: "error_no_escaped_space",
input: "error 404 Not Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_invalid_status_code",
input: "error 123 abc",
wantErr: ErrInvalidArguments,
},
// proxy directive tests
{
name: "proxy_valid",
input: "proxy http://localhost:8080",
wantErr: nil,
},
{
name: "proxy_missing_target",
input: "proxy",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_too_many_args",
input: "proxy http://localhost:8080 extra",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_invalid_url",
input: "proxy invalid_url",
wantErr: ErrInvalidArguments,
},
// unknown directive test
{
name: "unknown_directive",
input: "unknown /",
wantErr: ErrUnknownDirective,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := Command{}
err := cmd.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}

View file

@ -9,7 +9,9 @@ var (
ErrInvalidArguments = E.New("invalid arguments") ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target") ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrInvalidCommandSequence = E.New("invalid command sequence") ErrInvalidCommandSequence = E.New("invalid command sequence")
ErrInvalidSetTarget = E.New("invalid `rule.set` target")
ErrExpectNoArg = ErrInvalidArguments.Withf("expect no arg")
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")
) )

View file

@ -0,0 +1,142 @@
package rules
import (
"net/http"
"net/url"
)
type (
FieldHandler struct {
set, add, remove CommandHandler
}
FieldModifier string
)
const (
ModFieldSet FieldModifier = "set"
ModFieldAdd FieldModifier = "add"
ModFieldRemove FieldModifier = "remove"
)
const (
FieldHeader = "header"
FieldQuery = "query"
FieldCookie = "cookie"
)
var modFields = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) *FieldHandler
}{
FieldHeader: {
help: Help{
command: FieldHeader,
args: map[string]string{
"key": "the header key",
"value": "the header value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
w.Header()[k] = []string{v}
}),
add: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h[k] = append(h[k], v)
}),
remove: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
delete(w.Header(), k)
}),
}
},
},
FieldQuery: {
help: Help{
command: FieldQuery,
args: map[string]string{
"key": "the query key",
"value": "the query value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Set(k, v)
})
return true
}),
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Add(k, v)
})
return true
}),
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Del(k)
})
return true
}),
}
},
},
FieldCookie: {
help: Help{
command: FieldCookie,
args: map[string]string{
"key": "the cookie key",
"value": "the cookie value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
for i, c := range cookies {
if c.Name == k {
cookies[i].Value = v
return cookies
}
}
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return true
}),
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return true
}),
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
index := -1
for i, c := range cookies {
if c.Name == k {
index = i
break
}
}
if index != -1 {
if len(cookies) == 1 {
return []*http.Cookie{}
}
return append(cookies[:index], cookies[index+1:]...)
}
return cookies
})
return true
}),
}
},
},
}

View file

@ -1,37 +1,34 @@
package rules package rules
import ( import (
"net"
"net/http" "net/http"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type RuleOn struct {
RuleOn struct { raw string
raw string checker Checker
check CheckFulfill }
}
CheckFulfill func(r *http.Request) bool
Checkers []CheckFulfill
)
const ( const (
OnHeader = "header" OnHeader = "header"
OnQuery = "query" OnQuery = "query"
OnCookie = "cookie" OnCookie = "cookie"
OnForm = "form" OnForm = "form"
OnPostForm = "postform" OnPostForm = "postform"
OnMethod = "method" OnMethod = "method"
OnPath = "path" OnPath = "path"
OnRemote = "remote" OnRemote = "remote"
OnBasicAuth = "basic_auth"
) )
var checkers = map[string]struct { var checkers = map[string]struct {
help Help help Help
validate ValidateFunc validate ValidateFunc
check func(r *http.Request, args any) bool builder func(args any) CheckFunc
}{ }{
OnHeader: { OnHeader: {
help: Help{ help: Help{
@ -42,8 +39,11 @@ var checkers = map[string]struct {
}, },
}, },
validate: toStrTuple, validate: toStrTuple,
check: func(r *http.Request, args any) bool { builder: func(args any) CheckFunc {
return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second k, v := args.(*StrTuple).Unpack()
return func(cached Cache, r *http.Request) bool {
return r.Header.Get(k) == v
}
}, },
}, },
OnQuery: { OnQuery: {
@ -55,8 +55,17 @@ var checkers = map[string]struct {
}, },
}, },
validate: toStrTuple, validate: toStrTuple,
check: func(r *http.Request, args any) bool { builder: func(args any) CheckFunc {
return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second k, v := args.(*StrTuple).Unpack()
return func(cached Cache, r *http.Request) bool {
queries := cached.GetQueries(r)[k]
for _, query := range queries {
if query == v {
return true
}
}
return false
}
}, },
}, },
OnCookie: { OnCookie: {
@ -68,14 +77,18 @@ var checkers = map[string]struct {
}, },
}, },
validate: toStrTuple, validate: toStrTuple,
check: func(r *http.Request, args any) bool { builder: func(args any) CheckFunc {
cookies := r.CookiesNamed(args.(StrTuple).First) k, v := args.(*StrTuple).Unpack()
for _, cookie := range cookies { return func(cached Cache, r *http.Request) bool {
if cookie.Value == args.(StrTuple).Second { cookies := cached.GetCookies(r)
return true for _, cookie := range cookies {
if cookie.Name == k &&
cookie.Value == v {
return true
}
} }
return false
} }
return false
}, },
}, },
OnForm: { OnForm: {
@ -87,8 +100,11 @@ var checkers = map[string]struct {
}, },
}, },
validate: toStrTuple, validate: toStrTuple,
check: func(r *http.Request, args any) bool { builder: func(args any) CheckFunc {
return r.FormValue(args.(StrTuple).First) == args.(StrTuple).Second k, v := args.(*StrTuple).Unpack()
return func(cached Cache, r *http.Request) bool {
return r.FormValue(k) == v
}
}, },
}, },
OnPostForm: { OnPostForm: {
@ -100,8 +116,11 @@ var checkers = map[string]struct {
}, },
}, },
validate: toStrTuple, validate: toStrTuple,
check: func(r *http.Request, args any) bool { builder: func(args any) CheckFunc {
return r.PostFormValue(args.(StrTuple).First) == args.(StrTuple).Second k, v := args.(*StrTuple).Unpack()
return func(cached Cache, r *http.Request) bool {
return r.PostFormValue(k) == v
}
}, },
}, },
OnMethod: { OnMethod: {
@ -112,8 +131,11 @@ var checkers = map[string]struct {
}, },
}, },
validate: validateMethod, validate: validateMethod,
check: func(r *http.Request, method any) bool { builder: func(args any) CheckFunc {
return r.Method == method.(string) method := args.(string)
return func(cached Cache, r *http.Request) bool {
return r.Method == method
}
}, },
}, },
OnPath: { OnPath: {
@ -127,12 +149,15 @@ var checkers = map[string]struct {
}, },
}, },
validate: validateURLPath, validate: validateURLPath,
check: func(r *http.Request, globPath any) bool { builder: func(args any) CheckFunc {
reqPath := r.URL.Path pat := args.(string)
if len(reqPath) > 0 && reqPath[0] != '/' { return func(cached Cache, r *http.Request) bool {
reqPath = "/" + reqPath reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
}
return strutils.GlobMatch(pat, reqPath)
} }
return strutils.GlobMatch(globPath.(string), reqPath)
}, },
}, },
OnRemote: { OnRemote: {
@ -143,16 +168,31 @@ var checkers = map[string]struct {
}, },
}, },
validate: validateCIDR, validate: validateCIDR,
check: func(r *http.Request, cidr any) bool { builder: func(args any) CheckFunc {
host, _, err := net.SplitHostPort(r.RemoteAddr) cidr := args.(types.CIDR)
if err != nil { return func(cached Cache, r *http.Request) bool {
host = r.RemoteAddr ip := cached.GetRemoteIP(r)
if ip == nil {
return false
}
return cidr.Contains(ip)
} }
ip := net.ParseIP(host) },
if ip == nil { },
return false OnBasicAuth: {
help: Help{
command: OnBasicAuth,
args: map[string]string{
"username": "the username",
"password": "the password encrypted with bcrypt",
},
},
validate: validateUserBCryptPassword,
builder: func(args any) CheckFunc {
cred := args.(*HashedCrendentials)
return func(cached Cache, r *http.Request) bool {
return cred.Match(cached.GetBasicAuth(r))
} }
return cidr.(*net.IPNet).Contains(ip)
}, },
}, },
} }
@ -162,7 +202,7 @@ func (on *RuleOn) Parse(v string) error {
on.raw = v on.raw = v
lines := strutils.SplitLine(v) lines := strutils.SplitLine(v)
checks := make(Checkers, 0, len(lines)) checkAnd := make(CheckMatchAll, 0, len(lines))
errs := E.NewBuilder("rule.on syntax errors") errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines { for i, line := range lines {
@ -174,10 +214,10 @@ func (on *RuleOn) Parse(v string) error {
errs.Add(err.Subjectf("line %d", i+1)) errs.Add(err.Subjectf("line %d", i+1))
continue continue
} }
checks = append(checks, parsed.matchOne()) checkAnd = append(checkAnd, parsed)
} }
on.check = checks.matchAll() on.checker = checkAnd
return errs.Error() return errs.Error()
} }
@ -185,28 +225,28 @@ func (on *RuleOn) String() string {
return on.raw return on.raw
} }
func (on *RuleOn) MarshalJSON() ([]byte, error) { func (on *RuleOn) MarshalText() ([]byte, error) {
return []byte("\"" + on.String() + "\""), nil return []byte(on.String()), nil
} }
func parseOn(line string) (Checkers, E.Error) { func parseOn(line string) (Checker, E.Error) {
ors := strutils.SplitRune(line, '|') ors := strutils.SplitRune(line, '|')
if len(ors) > 1 { if len(ors) > 1 {
errs := E.NewBuilder("rule.on syntax errors") errs := E.NewBuilder("rule.on syntax errors")
checks := make([]CheckFulfill, len(ors)) checkOr := make(CheckMatchSingle, len(ors))
for i, or := range ors { for i, or := range ors {
curCheckers, err := parseOn(or) curCheckers, err := parseOn(or)
if err != nil { if err != nil {
errs.Add(err) errs.Add(err)
continue continue
} }
checks[i] = curCheckers[0] checkOr[i] = curCheckers.(CheckFunc)
} }
if err := errs.Error(); err != nil { if err := errs.Error(); err != nil {
return nil, err return nil, err
} }
return checks, nil return checkOr, nil
} }
subject, args, err := parse(line) subject, args, err := parse(line)
@ -224,31 +264,5 @@ func parseOn(line string) (Checkers, E.Error) {
return nil, err.Subject(subject).Withf("%s", checker.help.String()) return nil, err.Subject(subject).Withf("%s", checker.help.String())
} }
return Checkers{ return checker.builder(validArgs), nil
func(r *http.Request) bool {
return checker.check(r, validArgs)
},
}, nil
}
func (checkers Checkers) matchOne() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if checker(r) {
return true
}
}
return false
}
}
func (checkers Checkers) matchAll() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if !checker(r) {
return false
}
}
return true
}
} }

View file

@ -0,0 +1,89 @@
package rules
import (
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr E.Error
}{
// header
{
name: "header_valid",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_invalid",
input: "header Connection",
wantErr: ErrInvalidArguments,
},
// query
{
name: "query_valid",
input: "query key value",
wantErr: nil,
},
{
name: "query_invalid",
input: "query key",
wantErr: ErrInvalidArguments,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method",
wantErr: ErrInvalidArguments,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_invalid",
input: "path",
wantErr: ErrInvalidArguments,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote",
wantErr: ErrInvalidArguments,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}

View file

@ -14,6 +14,7 @@ var escapedChars = map[rune]rune{
'\'': '\'', '\'': '\'',
'"': '"', '"': '"',
'\\': '\\', '\\': '\\',
'$': '$',
' ': ' ', ' ': ' ',
} }

View file

@ -94,3 +94,13 @@ func TestParser(t *testing.T) {
} }
}) })
} }
func BenchmarkParser(b *testing.B) {
const input = `error 403 "Forbidden "foo" "bar""\ baz`
for range b.N {
_, _, err := parse(input)
if err != nil {
b.Fatal(err)
}
}
}

View file

@ -1,6 +1,7 @@
package rules package rules
import ( import (
"encoding/json"
"net/http" "net/http"
) )
@ -26,7 +27,7 @@ type (
on: method POST | method PUT on: method POST | method PUT
do: error 403 Forbidden do: error 403 Forbidden
*/ */
Rules []Rule Rules []*Rule
/* /*
Rule is a rule for a reverse proxy. Rule is a rule for a reverse proxy.
It do `Do` when `On` matches. It do `Do` when `On` matches.
@ -52,52 +53,72 @@ type (
// if no rule matches, the default rule is executed // if no rule matches, the default rule is executed
// if no rule matches and default rule is not set, // if no rule matches and default rule is not set,
// the request is passed to the upstream. // the request is passed to the upstream.
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { func (rules Rules) BuildHandler(caller string, up http.Handler) http.HandlerFunc {
var ( var defaultRule *Rule
defaultRule Rule
defaultRuleIndex int
)
nonDefaultRules := make(Rules, 0, len(rules))
for i, rule := range rules { for i, rule := range rules {
if rule.Name == "default" { if rule.Name == "default" {
defaultRule = rule defaultRule = rule
defaultRuleIndex = i nonDefaultRules = append(nonDefaultRules, rules[:i]...)
nonDefaultRules = append(nonDefaultRules, rules[i+1:]...)
break break
} }
} }
rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...)
// free allocated empty slices
// before encapsulating them into the handlerFunc.
if len(rules) == 0 { if len(rules) == 0 {
if defaultRule.Do.isBypass() { if defaultRule.Do.isBypass() {
return up.ServeHTTP return up.ServeHTTP
} }
rules = []Rule{} return func(w http.ResponseWriter, r *http.Request) {
cache := NewCache()
defer cache.Release()
if defaultRule.Do.exec.Handle(cache, w, r) {
up.ServeHTTP(w, r)
}
}
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
hasMatch := false cache := NewCache()
for _, rule := range rules { defer cache.Release()
if rule.On.check(r) {
for _, rule := range nonDefaultRules {
if rule.Check(cache, r) {
if rule.Do.isBypass() { if rule.Do.isBypass() {
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return return
} }
rule.Do.exec.HandlerFunc(w, r) if !rule.Handle(cache, w, r) {
if !rule.Do.exec.proceed {
return return
} }
hasMatch = true
} }
} }
if hasMatch || defaultRule.Do.isBypass() { // bypass or proceed
if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) {
up.ServeHTTP(w, r) up.ServeHTTP(w, r)
return
} }
defaultRule.Do.exec.HandlerFunc(w, r)
} }
} }
func (rules Rules) MarshalJSON() ([]byte, error) {
names := make([]string, len(rules))
for i, rule := range rules {
names[i] = rule.Name
}
return json.Marshal(names)
}
func (rule *Rule) String() string {
return rule.Name
}
func (rule *Rule) Check(cached Cache, r *http.Request) bool {
return rule.On.checker.Check(cached, r)
}
func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
proceed = rule.Do.exec.Handle(cached, w, r)
return
}

View file

@ -3,249 +3,44 @@ package rules
import ( import (
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestParseSubjectArgs(t *testing.T) {
t.Run("basic", func(t *testing.T) {
subject, args, err := parse("rewrite / /foo/bar")
ExpectNoError(t, err)
ExpectEqual(t, subject, "rewrite")
ExpectDeepEqual(t, args, []string{"/", "/foo/bar"})
})
t.Run("with quotes", func(t *testing.T) {
subject, args, err := parse(`error 403 "Forbidden 'foo' 'bar'."`)
ExpectNoError(t, err)
ExpectEqual(t, subject, "error")
ExpectDeepEqual(t, args, []string{"403", "Forbidden 'foo' 'bar'."})
})
t.Run("with escaped", func(t *testing.T) {
subject, args, err := parse(`error 403 Forbidden\ \"foo\"\ \"bar\".`)
ExpectNoError(t, err)
ExpectEqual(t, subject, "error")
ExpectDeepEqual(t, args, []string{"403", "Forbidden \"foo\" \"bar\"."})
})
}
func TestParseCommands(t *testing.T) {
tests := []struct {
name string
input string
wantErr error
}{
// bypass tests
{
name: "bypass_valid",
input: "bypass",
wantErr: nil,
},
{
name: "bypass_invalid_with_args",
input: "bypass /",
wantErr: ErrInvalidArguments,
},
// rewrite tests
{
name: "rewrite_valid",
input: "rewrite / /foo/bar",
wantErr: nil,
},
{
name: "rewrite_missing_target",
input: "rewrite /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_too_many_args",
input: "rewrite / / /",
wantErr: ErrInvalidArguments,
},
{
name: "rewrite_no_leading_slash",
input: "rewrite abc /",
wantErr: ErrInvalidArguments,
},
// serve tests
{
name: "serve_valid",
input: "serve /var/www",
wantErr: nil,
},
{
name: "serve_missing_path",
input: "serve ",
wantErr: ErrInvalidArguments,
},
{
name: "serve_too_many_args",
input: "serve / / /",
wantErr: ErrInvalidArguments,
},
// redirect tests
{
name: "redirect_valid",
input: "redirect /",
wantErr: nil,
},
{
name: "redirect_too_many_args",
input: "redirect / /",
wantErr: ErrInvalidArguments,
},
// error directive tests
{
name: "error_valid",
input: "error 404 Not\\ Found",
wantErr: nil,
},
{
name: "error_missing_status_code",
input: "error Not\\ Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_too_many_args",
input: "error 404 Not\\ Found extra",
wantErr: ErrInvalidArguments,
},
{
name: "error_no_escaped_space",
input: "error 404 Not Found",
wantErr: ErrInvalidArguments,
},
{
name: "error_invalid_status_code",
input: "error 123 abc",
wantErr: ErrInvalidArguments,
},
// proxy directive tests
{
name: "proxy_valid",
input: "proxy localhost:8080",
wantErr: nil,
},
{
name: "proxy_missing_target",
input: "proxy",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_too_many_args",
input: "proxy localhost:8080 extra",
wantErr: ErrInvalidArguments,
},
{
name: "proxy_invalid_url",
input: "proxy :invalid_url",
wantErr: ErrInvalidArguments,
},
// unknown directive test
{
name: "unknown_directive",
input: "unknown /",
wantErr: ErrUnknownDirective,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := Command{}
err := cmd.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr E.Error
}{
// header
{
name: "header_valid",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_invalid",
input: "header Connection",
wantErr: ErrInvalidArguments,
},
// query
{
name: "query_valid",
input: "query key value",
wantErr: nil,
},
{
name: "query_invalid",
input: "query key",
wantErr: ErrInvalidArguments,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method",
wantErr: ErrInvalidArguments,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_invalid",
input: "path",
wantErr: ErrInvalidArguments,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote",
wantErr: ErrInvalidArguments,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
func TestParseRule(t *testing.T) { func TestParseRule(t *testing.T) {
// test := map[string]any{ test := []map[string]any{
// "name": "test", {
// "on": "method GET", "name": "test",
// "do": "bypass", "on": "method POST",
// } "do": "error 403 Forbidden",
},
{
"name": "auth",
"on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`,
"do": "bypass",
},
{
"name": "default",
"do": "require_basic_auth any_realm",
},
}
var rules struct {
Rules Rules
}
err := utils.Deserialize(utils.SerializedObject{"rules": test}, &rules)
ExpectNoError(t, err)
ExpectEqual(t, len(rules.Rules), len(test))
ExpectEqual(t, rules.Rules[0].Name, "test")
ExpectEqual(t, rules.Rules[0].On.String(), "method POST")
ExpectEqual(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
ExpectEqual(t, rules.Rules[1].Name, "auth")
ExpectEqual(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
ExpectEqual(t, rules.Rules[1].Do.String(), "bypass")
ExpectEqual(t, rules.Rules[2].Name, "default")
ExpectEqual(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
} }
// TODO: real tests.

View file

@ -1,6 +1,7 @@
package rules package rules
import ( import (
"fmt"
"os" "os"
"path" "path"
"strings" "strings"
@ -11,19 +12,31 @@ import (
) )
type ( type (
ValidateFunc func(args []string) (any, E.Error) ValidateFunc func(args []string) (any, E.Error)
StrTuple struct { Tuple[T1, T2 any] struct {
First, Second string First T1
Second T2
} }
StrTuple = Tuple[string, string]
) )
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
return t.First, t.Second
}
func (t *Tuple[T1, T2]) String() string {
return fmt.Sprintf("%v:%v", t.First, t.Second)
}
// toStrTuple returns *StrTuple.
func toStrTuple(args []string) (any, E.Error) { func toStrTuple(args []string) (any, E.Error) {
if len(args) != 2 { if len(args) != 2 {
return nil, ErrExpectTwoArgs return nil, ErrExpectTwoArgs
} }
return StrTuple{args[0], args[1]}, nil return &StrTuple{args[0], args[1]}, nil
} }
// validateURL returns types.URL with the URL validated.
func validateURL(args []string) (any, E.Error) { func validateURL(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -35,6 +48,7 @@ func validateURL(args []string) (any, E.Error) {
return u, nil return u, nil
} }
// validateAbsoluteURL returns types.URL with the URL validated.
func validateAbsoluteURL(args []string) (any, E.Error) { func validateAbsoluteURL(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -52,6 +66,7 @@ func validateAbsoluteURL(args []string) (any, E.Error) {
return u, nil return u, nil
} }
// validateCIDR returns types.CIDR with the CIDR validated.
func validateCIDR(args []string) (any, E.Error) { func validateCIDR(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -66,6 +81,7 @@ func validateCIDR(args []string) (any, E.Error) {
return cidr, nil return cidr, nil
} }
// validateURLPath returns string with the path validated.
func validateURLPath(args []string) (any, E.Error) { func validateURLPath(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -86,6 +102,7 @@ func validateURLPath(args []string) (any, E.Error) {
return p, nil return p, nil
} }
// validateURLPaths returns []string with each element validated.
func validateURLPaths(paths []string) (any, E.Error) { func validateURLPaths(paths []string) (any, E.Error) {
errs := E.NewBuilder("invalid url paths") errs := E.NewBuilder("invalid url paths")
for i, p := range paths { for i, p := range paths {
@ -102,6 +119,7 @@ func validateURLPaths(paths []string) (any, E.Error) {
return paths, nil return paths, nil
} }
// validateFSPath returns string with the path validated.
func validateFSPath(args []string) (any, E.Error) { func validateFSPath(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -113,6 +131,7 @@ func validateFSPath(args []string) (any, E.Error) {
return p, nil return p, nil
} }
// validateMethod returns string with the method validated.
func validateMethod(args []string) (any, E.Error) { func validateMethod(args []string) (any, E.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
@ -123,3 +142,31 @@ func validateMethod(args []string) (any, E.Error) {
} }
return method, nil return method, nil
} }
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
func validateUserBCryptPassword(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return BCryptCrendentials(args[0], []byte(args[1])), nil
}
// validateModField returns CommandHandler with the field validated.
func validateModField(mod FieldModifier, args []string) (CommandHandler, E.Error) {
setField, ok := modFields[args[0]]
if !ok {
return nil, ErrInvalidSetTarget.Subject(args[0])
}
validArgs, err := setField.validate(args[1:])
if err != nil {
return nil, err.Withf(setField.help.String())
}
modder := setField.builder(validArgs)
switch mod {
case ModFieldAdd:
return modder.add, nil
case ModFieldRemove:
return modder.remove, nil
}
return modder.set, nil
}