mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix proxy rules behavior and implemented a few more rules and commands, dependencies upgrade
This commit is contained in:
parent
f2df756c17
commit
0ce7f29976
20 changed files with 991 additions and 443 deletions
|
@ -23,7 +23,7 @@ lint:
|
||||||
enabled:
|
enabled:
|
||||||
- hadolint@2.12.1-beta
|
- hadolint@2.12.1-beta
|
||||||
- actionlint@1.7.6
|
- actionlint@1.7.6
|
||||||
- checkov@3.2.350
|
- checkov@3.2.352
|
||||||
- git-diff-check
|
- git-diff-check
|
||||||
- gofmt@1.20.4
|
- gofmt@1.20.4
|
||||||
- golangci-lint@1.63.4
|
- golangci-lint@1.63.4
|
||||||
|
@ -32,7 +32,7 @@ lint:
|
||||||
- prettier@3.4.2
|
- prettier@3.4.2
|
||||||
- shellcheck@0.10.0
|
- shellcheck@0.10.0
|
||||||
- shfmt@3.6.0
|
- shfmt@3.6.0
|
||||||
- trufflehog@3.88.1
|
- trufflehog@3.88.2
|
||||||
actions:
|
actions:
|
||||||
disabled:
|
disabled:
|
||||||
- trunk-announce
|
- trunk-announce
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"github.com/yusing/go-proxy/pkg"
|
"github.com/yusing/go-proxy/pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var rawLogger = log.New(os.Stdout, "", 0)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
args := common.GetArgs()
|
args := common.GetArgs()
|
||||||
|
|
||||||
|
@ -31,12 +33,12 @@ func main() {
|
||||||
if err := query.ReloadServer(); err != nil {
|
if err := query.ReloadServer(); err != nil {
|
||||||
E.LogFatal("server reload error", err)
|
E.LogFatal("server reload error", err)
|
||||||
}
|
}
|
||||||
logging.Info().Msg("ok")
|
rawLogger.Println("ok")
|
||||||
return
|
return
|
||||||
case common.CommandListIcons:
|
case common.CommandListIcons:
|
||||||
icons, err := internal.ListAvailableIcons()
|
icons, err := internal.ListAvailableIcons()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
rawLogger.Fatal(err)
|
||||||
}
|
}
|
||||||
printJSON(icons)
|
printJSON(icons)
|
||||||
return
|
return
|
||||||
|
@ -139,6 +141,5 @@ func printJSON(obj any) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Fatal().Err(err).Send()
|
logging.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
rawLogger := log.New(os.Stdout, "", 0)
|
|
||||||
rawLogger.Print(string(j)) // raw output for convenience using "jq"
|
rawLogger.Print(string(j)) // raw output for convenience using "jq"
|
||||||
}
|
}
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -15,6 +15,7 @@ require (
|
||||||
github.com/prometheus/client_golang v1.20.5
|
github.com/prometheus/client_golang v1.20.5
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.4.0
|
github.com/puzpuzpuz/xsync/v3 v3.4.0
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
|
golang.org/x/crypto v0.32.0
|
||||||
golang.org/x/net v0.34.0
|
golang.org/x/net v0.34.0
|
||||||
golang.org/x/text v0.21.0
|
golang.org/x/text v0.21.0
|
||||||
golang.org/x/time v0.9.0
|
golang.org/x/time v0.9.0
|
||||||
|
@ -43,7 +44,7 @@ require (
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/klauspost/compress v1.17.11 // indirect
|
github.com/klauspost/compress v1.17.11 // indirect
|
||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/miekg/dns v1.1.62 // indirect
|
github.com/miekg/dns v1.1.62 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
@ -65,7 +66,6 @@ require (
|
||||||
go.opentelemetry.io/otel/metric v1.33.0 // indirect
|
go.opentelemetry.io/otel/metric v1.33.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk v1.30.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.30.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.33.0 // indirect
|
go.opentelemetry.io/otel/trace v1.33.0 // indirect
|
||||||
golang.org/x/crypto v0.32.0 // indirect
|
|
||||||
golang.org/x/mod v0.22.0 // indirect
|
golang.org/x/mod v0.22.0 // indirect
|
||||||
golang.org/x/oauth2 v0.25.0 // indirect
|
golang.org/x/oauth2 v0.25.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.10.0 // indirect
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -86,8 +86,9 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
|
|
@ -139,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Raw.Rules) > 0 {
|
if len(r.Raw.Rules) > 0 {
|
||||||
r.handler = r.Raw.Rules.BuildHandler(r.handler)
|
r.handler = r.Raw.Rules.BuildHandler(r.TargetName(), r.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.HealthMon != nil {
|
if r.HealthMon != nil {
|
||||||
|
|
117
internal/route/rules/cache.go
Normal file
117
internal/route/rules/cache.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache is a map of cached values for a request.
|
||||||
|
// It prevents the same value from being parsed multiple times.
|
||||||
|
type (
|
||||||
|
Cache map[string]any
|
||||||
|
UpdateFunc[T any] func(T) T
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CacheKeyQueries = "queries"
|
||||||
|
CacheKeyCookies = "cookies"
|
||||||
|
CacheKeyRemoteIP = "remote_ip"
|
||||||
|
CacheKeyBasicAuth = "basic_auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
var cacheKeys = []string{
|
||||||
|
CacheKeyQueries,
|
||||||
|
CacheKeyCookies,
|
||||||
|
CacheKeyRemoteIP,
|
||||||
|
CacheKeyBasicAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
var cachePool = &sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return make(Cache)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCache returns a new Cached.
|
||||||
|
func NewCache() Cache {
|
||||||
|
return cachePool.Get().(Cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release clear the contents of the Cached and returns it to the pool.
|
||||||
|
func (c Cache) Release() {
|
||||||
|
for _, k := range cacheKeys {
|
||||||
|
delete(c, k)
|
||||||
|
}
|
||||||
|
cachePool.Put(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQueries returns the queries.
|
||||||
|
// If r does not have queries, an empty map is returned.
|
||||||
|
func (c Cache) GetQueries(r *http.Request) url.Values {
|
||||||
|
v, ok := c[CacheKeyQueries]
|
||||||
|
if !ok {
|
||||||
|
v = r.URL.Query()
|
||||||
|
c[CacheKeyQueries] = v
|
||||||
|
}
|
||||||
|
return v.(url.Values)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) {
|
||||||
|
queries := c.GetQueries(r)
|
||||||
|
update(queries)
|
||||||
|
r.URL.RawQuery = queries.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCookies returns the cookies.
|
||||||
|
// If r does not have cookies, an empty slice is returned.
|
||||||
|
func (c Cache) GetCookies(r *http.Request) []*http.Cookie {
|
||||||
|
v, ok := c[CacheKeyCookies]
|
||||||
|
if !ok {
|
||||||
|
v = r.Cookies()
|
||||||
|
c[CacheKeyCookies] = v
|
||||||
|
}
|
||||||
|
return v.([]*http.Cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) {
|
||||||
|
cookies := update(c.GetCookies(r))
|
||||||
|
c[CacheKeyCookies] = cookies
|
||||||
|
r.Header.Del("Cookie")
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
r.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRemoteIP returns the remote ip address.
|
||||||
|
// If r.RemoteAddr is not a valid ip address, nil is returned.
|
||||||
|
func (c Cache) GetRemoteIP(r *http.Request) net.IP {
|
||||||
|
v, ok := c[CacheKeyRemoteIP]
|
||||||
|
if !ok {
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
host = r.RemoteAddr
|
||||||
|
}
|
||||||
|
v = net.ParseIP(host)
|
||||||
|
c[CacheKeyRemoteIP] = v
|
||||||
|
}
|
||||||
|
return v.(net.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBasicAuth returns *Credentials the basic auth username and password.
|
||||||
|
// If r does not have basic auth, nil is returned.
|
||||||
|
func (c Cache) GetBasicAuth(r *http.Request) *Credentials {
|
||||||
|
v, ok := c[CacheKeyBasicAuth]
|
||||||
|
if !ok {
|
||||||
|
u, p, ok := r.BasicAuth()
|
||||||
|
if ok {
|
||||||
|
v = &Credentials{u, []byte(p)}
|
||||||
|
c[CacheKeyBasicAuth] = v
|
||||||
|
} else {
|
||||||
|
c[CacheKeyBasicAuth] = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v.(*Credentials)
|
||||||
|
}
|
34
internal/route/rules/check_on.go
Normal file
34
internal/route/rules/check_on.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type (
|
||||||
|
CheckFunc func(cached Cache, r *http.Request) bool
|
||||||
|
Checker interface {
|
||||||
|
Check(cached Cache, r *http.Request) bool
|
||||||
|
}
|
||||||
|
CheckMatchSingle []Checker
|
||||||
|
CheckMatchAll []Checker
|
||||||
|
)
|
||||||
|
|
||||||
|
func (checker CheckFunc) Check(cached Cache, r *http.Request) bool {
|
||||||
|
return checker(cached, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool {
|
||||||
|
for _, check := range checkers {
|
||||||
|
if check.Check(cached, r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool {
|
||||||
|
for _, check := range checkers {
|
||||||
|
if !check.Check(cached, r) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
51
internal/route/rules/command.go
Normal file
51
internal/route/rules/command.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type (
|
||||||
|
CommandHandler interface {
|
||||||
|
// CommandHandler can read and modify the values
|
||||||
|
// then handle the request
|
||||||
|
// finally proceed to next command (or return) base on situation
|
||||||
|
Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
|
}
|
||||||
|
// StaticCommand will run then proceed to next command or reverse proxy.
|
||||||
|
StaticCommand http.HandlerFunc
|
||||||
|
// ReturningCommand will run then return immediately.
|
||||||
|
ReturningCommand http.HandlerFunc
|
||||||
|
// DynamicCommand will return base on the request
|
||||||
|
// and can raed or modify the values.
|
||||||
|
DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
|
// BypassCommand will skip all the following commands
|
||||||
|
// and directly return to reverse proxy.
|
||||||
|
BypassCommand struct{}
|
||||||
|
// Commands is a slice of CommandHandler.
|
||||||
|
Commands []CommandHandler
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c StaticCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
c(w, r)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ReturningCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
c(w, r)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
return c(cached, w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
for _, cmd := range c {
|
||||||
|
if !cmd.Handle(cached, w, r) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
27
internal/route/rules/crypto.go
Normal file
27
internal/route/rules/crypto.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
type (
|
||||||
|
HashedCrendentials struct {
|
||||||
|
Username string
|
||||||
|
CheckMatch func(inputPwd []byte) bool
|
||||||
|
}
|
||||||
|
Credentials struct {
|
||||||
|
Username string
|
||||||
|
Password []byte
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func BCryptCrendentials(username string, hashedPassword []byte) *HashedCrendentials {
|
||||||
|
return &HashedCrendentials{username, func(inputPwd []byte) bool {
|
||||||
|
return bcrypt.CompareHashAndPassword(hashedPassword, inputPwd) == nil
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HashedCrendentials) Match(cred *Credentials) bool {
|
||||||
|
if cred == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return hc.Username == cred.Username && hc.CheckMatch(cred.Password)
|
||||||
|
}
|
|
@ -16,28 +16,27 @@ import (
|
||||||
type (
|
type (
|
||||||
Command struct {
|
Command struct {
|
||||||
raw string
|
raw string
|
||||||
exec *CommandExecutor
|
exec CommandHandler
|
||||||
}
|
|
||||||
CommandExecutor struct {
|
|
||||||
directive string
|
|
||||||
http.HandlerFunc
|
|
||||||
proceed bool
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CommandRewrite = "rewrite"
|
CommandRewrite = "rewrite"
|
||||||
CommandServe = "serve"
|
CommandServe = "serve"
|
||||||
CommandProxy = "proxy"
|
CommandProxy = "proxy"
|
||||||
CommandRedirect = "redirect"
|
CommandRedirect = "redirect"
|
||||||
CommandError = "error"
|
CommandError = "error"
|
||||||
CommandBypass = "bypass"
|
CommandRequireBasicAuth = "require_basic_auth"
|
||||||
|
CommandSet = "set"
|
||||||
|
CommandAdd = "add"
|
||||||
|
CommandRemove = "remove"
|
||||||
|
CommandBypass = "bypass"
|
||||||
)
|
)
|
||||||
|
|
||||||
var commands = map[string]struct {
|
var commands = map[string]struct {
|
||||||
help Help
|
help Help
|
||||||
validate ValidateFunc
|
validate ValidateFunc
|
||||||
build func(args any) *CommandExecutor
|
build func(args any) CommandHandler
|
||||||
}{
|
}{
|
||||||
CommandRewrite: {
|
CommandRewrite: {
|
||||||
help: Help{
|
help: Help{
|
||||||
|
@ -53,25 +52,22 @@ var commands = map[string]struct {
|
||||||
}
|
}
|
||||||
return validateURLPaths(args)
|
return validateURLPaths(args)
|
||||||
},
|
},
|
||||||
build: func(args any) *CommandExecutor {
|
build: func(args any) CommandHandler {
|
||||||
a := args.([]string)
|
a := args.([]string)
|
||||||
orig, repl := a[0], a[1]
|
orig, repl := a[0], a[1]
|
||||||
return &CommandExecutor{
|
return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
path := r.URL.Path
|
||||||
path := r.URL.Path
|
if len(path) > 0 && path[0] != '/' {
|
||||||
if len(path) > 0 && path[0] != '/' {
|
path = "/" + path
|
||||||
path = "/" + path
|
}
|
||||||
}
|
if !strings.HasPrefix(path, orig) {
|
||||||
if !strings.HasPrefix(path, orig) {
|
return
|
||||||
return
|
}
|
||||||
}
|
path = repl + path[len(orig):]
|
||||||
path = repl + path[len(orig):]
|
r.URL.Path = path
|
||||||
r.URL.Path = path
|
r.URL.RawPath = r.URL.EscapedPath()
|
||||||
r.URL.RawPath = r.URL.EscapedPath()
|
r.RequestURI = r.URL.RequestURI()
|
||||||
r.RequestURI = r.URL.RequestURI()
|
})
|
||||||
},
|
|
||||||
proceed: true,
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandServe: {
|
CommandServe: {
|
||||||
|
@ -82,14 +78,11 @@ var commands = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateFSPath,
|
validate: validateFSPath,
|
||||||
build: func(args any) *CommandExecutor {
|
build: func(args any) CommandHandler {
|
||||||
root := args.(string)
|
root := args.(string)
|
||||||
return &CommandExecutor{
|
return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
})
|
||||||
},
|
|
||||||
proceed: false,
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandRedirect: {
|
CommandRedirect: {
|
||||||
|
@ -100,14 +93,11 @@ var commands = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateURL,
|
validate: validateURL,
|
||||||
build: func(args any) *CommandExecutor {
|
build: func(args any) CommandHandler {
|
||||||
target := args.(types.URL).String()
|
target := args.(types.URL).String()
|
||||||
return &CommandExecutor{
|
return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
})
|
||||||
},
|
|
||||||
proceed: false,
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandError: {
|
CommandError: {
|
||||||
|
@ -130,17 +120,34 @@ var commands = map[string]struct {
|
||||||
if !gphttp.IsStatusCodeValid(code) {
|
if !gphttp.IsStatusCodeValid(code) {
|
||||||
return nil, ErrInvalidArguments.Subject(codeStr)
|
return nil, ErrInvalidArguments.Subject(codeStr)
|
||||||
}
|
}
|
||||||
return []any{code, text}, nil
|
return &Tuple[int, string]{code, text}, nil
|
||||||
},
|
},
|
||||||
build: func(args any) *CommandExecutor {
|
build: func(args any) CommandHandler {
|
||||||
a := args.([]any)
|
code, text := args.(*Tuple[int, string]).Unpack()
|
||||||
code, text := a[0].(int), a[1].(string)
|
return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
return &CommandExecutor{
|
http.Error(w, text, code)
|
||||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
})
|
||||||
http.Error(w, text, code)
|
},
|
||||||
},
|
},
|
||||||
proceed: false,
|
CommandRequireBasicAuth: {
|
||||||
|
help: Help{
|
||||||
|
command: CommandRequireBasicAuth,
|
||||||
|
args: map[string]string{
|
||||||
|
"realm": "the authentication realm",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(args []string) (any, E.Error) {
|
||||||
|
if len(args) == 1 {
|
||||||
|
return args[0], nil
|
||||||
}
|
}
|
||||||
|
return nil, ErrExpectOneArg
|
||||||
|
},
|
||||||
|
build: func(args any) CommandHandler {
|
||||||
|
realm := args.(string)
|
||||||
|
return ReturningCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
})
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandProxy: {
|
CommandProxy: {
|
||||||
|
@ -151,30 +158,69 @@ var commands = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateAbsoluteURL,
|
validate: validateAbsoluteURL,
|
||||||
build: func(args any) *CommandExecutor {
|
build: func(args any) CommandHandler {
|
||||||
target := args.(types.URL)
|
target := args.(types.URL)
|
||||||
if target.Scheme == "" {
|
if target.Scheme == "" {
|
||||||
target.Scheme = "http"
|
target.Scheme = "http"
|
||||||
}
|
}
|
||||||
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
|
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
|
||||||
return &CommandExecutor{
|
return ReturningCommand(rp.ServeHTTP)
|
||||||
HandlerFunc: rp.ServeHTTP,
|
},
|
||||||
proceed: false,
|
},
|
||||||
}
|
CommandSet: {
|
||||||
|
help: Help{
|
||||||
|
command: CommandSet,
|
||||||
|
args: map[string]string{
|
||||||
|
"field": "the field to set",
|
||||||
|
"value": "the value to set",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(args []string) (any, E.Error) {
|
||||||
|
return validateModField(ModFieldSet, args)
|
||||||
|
},
|
||||||
|
build: func(args any) CommandHandler {
|
||||||
|
return args.(CommandHandler)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CommandAdd: {
|
||||||
|
help: Help{
|
||||||
|
command: CommandAdd,
|
||||||
|
args: map[string]string{
|
||||||
|
"field": "the field to add",
|
||||||
|
"value": "the value to add",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(args []string) (any, E.Error) {
|
||||||
|
return validateModField(ModFieldAdd, args)
|
||||||
|
},
|
||||||
|
build: func(args any) CommandHandler {
|
||||||
|
return args.(CommandHandler)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CommandRemove: {
|
||||||
|
help: Help{
|
||||||
|
command: CommandRemove,
|
||||||
|
args: map[string]string{
|
||||||
|
"field": "the field to remove",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: func(args []string) (any, E.Error) {
|
||||||
|
return validateModField(ModFieldRemove, args)
|
||||||
|
},
|
||||||
|
build: func(args any) CommandHandler {
|
||||||
|
return args.(CommandHandler)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse implements strutils.Parser.
|
// Parse implements strutils.Parser.
|
||||||
func (cmd *Command) Parse(v string) error {
|
func (cmd *Command) Parse(v string) error {
|
||||||
cmd.raw = v
|
|
||||||
|
|
||||||
lines := strutils.SplitLine(v)
|
lines := strutils.SplitLine(v)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
executors := make([]*CommandExecutor, 0, len(lines))
|
executors := make([]CommandHandler, 0, len(lines))
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if line == "" {
|
if line == "" {
|
||||||
continue
|
continue
|
||||||
|
@ -189,7 +235,7 @@ func (cmd *Command) Parse(v string) error {
|
||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return ErrInvalidArguments.Subject(directive)
|
return ErrInvalidArguments.Subject(directive)
|
||||||
}
|
}
|
||||||
executors = append(executors, nil)
|
executors = append(executors, BypassCommand{})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,48 +248,58 @@ func (cmd *Command) Parse(v string) error {
|
||||||
return err.Subject(directive).Withf("%s", builder.help.String())
|
return err.Subject(directive).Withf("%s", builder.help.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
exec := builder.build(validArgs)
|
executors = append(executors, builder.build(validArgs))
|
||||||
exec.directive = directive
|
}
|
||||||
executors = append(executors, exec)
|
|
||||||
|
if len(executors) == 0 {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
exec, err := buildCmd(executors)
|
exec, err := buildCmd(executors)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.raw = v
|
||||||
cmd.exec = exec
|
cmd.exec = exec
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) {
|
func buildCmd(executors []CommandHandler) (CommandHandler, error) {
|
||||||
for i, exec := range executors {
|
for i, exec := range executors {
|
||||||
if !exec.proceed && i != len(executors)-1 {
|
switch exec.(type) {
|
||||||
return nil, ErrInvalidCommandSequence.
|
case ReturningCommand, BypassCommand:
|
||||||
Withf("%s cannot follow %s", exec, executors[i+1])
|
if i != len(executors)-1 {
|
||||||
|
return nil, ErrInvalidCommandSequence.
|
||||||
|
Withf("a returning / bypass command must be the last command")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &CommandExecutor{
|
|
||||||
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
return Commands(executors), nil
|
||||||
for _, exec := range executors {
|
|
||||||
exec.HandlerFunc(w, r)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
proceed: executors[len(executors)-1].proceed,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Command is purely "bypass" or empty.
|
||||||
func (cmd *Command) isBypass() bool {
|
func (cmd *Command) isBypass() bool {
|
||||||
return cmd.exec == nil
|
if cmd == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch cmd := cmd.exec.(type) {
|
||||||
|
case BypassCommand:
|
||||||
|
return true
|
||||||
|
case Commands:
|
||||||
|
// bypass command is always the last one
|
||||||
|
_, ok := cmd[len(cmd)-1].(BypassCommand)
|
||||||
|
return ok
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cmd *Command) String() string {
|
func (cmd *Command) String() string {
|
||||||
return cmd.raw
|
return cmd.raw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cmd *Command) MarshalJSON() ([]byte, error) {
|
func (cmd *Command) MarshalText() ([]byte, error) {
|
||||||
return []byte("\"" + cmd.String() + "\""), nil
|
return []byte(cmd.String()), nil
|
||||||
}
|
|
||||||
|
|
||||||
func (exec *CommandExecutor) String() string {
|
|
||||||
return exec.directive
|
|
||||||
}
|
}
|
||||||
|
|
140
internal/route/rules/do_test.go
Normal file
140
internal/route/rules/do_test.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseCommands(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// bypass tests
|
||||||
|
{
|
||||||
|
name: "bypass_valid",
|
||||||
|
input: "bypass",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bypass_invalid_with_args",
|
||||||
|
input: "bypass /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// rewrite tests
|
||||||
|
{
|
||||||
|
name: "rewrite_valid",
|
||||||
|
input: "rewrite / /foo/bar",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rewrite_missing_target",
|
||||||
|
input: "rewrite /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rewrite_too_many_args",
|
||||||
|
input: "rewrite / / /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rewrite_no_leading_slash",
|
||||||
|
input: "rewrite abc /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// serve tests
|
||||||
|
{
|
||||||
|
name: "serve_valid",
|
||||||
|
input: "serve /var/www",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "serve_missing_path",
|
||||||
|
input: "serve ",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "serve_too_many_args",
|
||||||
|
input: "serve / / /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// redirect tests
|
||||||
|
{
|
||||||
|
name: "redirect_valid",
|
||||||
|
input: "redirect /",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "redirect_too_many_args",
|
||||||
|
input: "redirect / /",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// error directive tests
|
||||||
|
{
|
||||||
|
name: "error_valid",
|
||||||
|
input: "error 404 Not\\ Found",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_missing_status_code",
|
||||||
|
input: "error Not\\ Found",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_too_many_args",
|
||||||
|
input: "error 404 Not\\ Found extra",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_no_escaped_space",
|
||||||
|
input: "error 404 Not Found",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error_invalid_status_code",
|
||||||
|
input: "error 123 abc",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// proxy directive tests
|
||||||
|
{
|
||||||
|
name: "proxy_valid",
|
||||||
|
input: "proxy http://localhost:8080",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy_missing_target",
|
||||||
|
input: "proxy",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy_too_many_args",
|
||||||
|
input: "proxy http://localhost:8080 extra",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxy_invalid_url",
|
||||||
|
input: "proxy invalid_url",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// unknown directive test
|
||||||
|
{
|
||||||
|
name: "unknown_directive",
|
||||||
|
input: "unknown /",
|
||||||
|
wantErr: ErrUnknownDirective,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cmd := Command{}
|
||||||
|
err := cmd.Parse(tt.input)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
ExpectError(t, tt.wantErr, err)
|
||||||
|
} else {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,7 +9,9 @@ var (
|
||||||
ErrInvalidArguments = E.New("invalid arguments")
|
ErrInvalidArguments = E.New("invalid arguments")
|
||||||
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
|
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
|
||||||
ErrInvalidCommandSequence = E.New("invalid command sequence")
|
ErrInvalidCommandSequence = E.New("invalid command sequence")
|
||||||
|
ErrInvalidSetTarget = E.New("invalid `rule.set` target")
|
||||||
|
|
||||||
|
ErrExpectNoArg = ErrInvalidArguments.Withf("expect no arg")
|
||||||
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
|
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
|
||||||
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")
|
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")
|
||||||
)
|
)
|
||||||
|
|
142
internal/route/rules/fields.go
Normal file
142
internal/route/rules/fields.go
Normal file
|
@ -0,0 +1,142 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
FieldHandler struct {
|
||||||
|
set, add, remove CommandHandler
|
||||||
|
}
|
||||||
|
FieldModifier string
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModFieldSet FieldModifier = "set"
|
||||||
|
ModFieldAdd FieldModifier = "add"
|
||||||
|
ModFieldRemove FieldModifier = "remove"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
FieldHeader = "header"
|
||||||
|
FieldQuery = "query"
|
||||||
|
FieldCookie = "cookie"
|
||||||
|
)
|
||||||
|
|
||||||
|
var modFields = map[string]struct {
|
||||||
|
help Help
|
||||||
|
validate ValidateFunc
|
||||||
|
builder func(args any) *FieldHandler
|
||||||
|
}{
|
||||||
|
FieldHeader: {
|
||||||
|
help: Help{
|
||||||
|
command: FieldHeader,
|
||||||
|
args: map[string]string{
|
||||||
|
"key": "the header key",
|
||||||
|
"value": "the header value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: toStrTuple,
|
||||||
|
builder: func(args any) *FieldHandler {
|
||||||
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return &FieldHandler{
|
||||||
|
set: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header()[k] = []string{v}
|
||||||
|
}),
|
||||||
|
add: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h := w.Header()
|
||||||
|
h[k] = append(h[k], v)
|
||||||
|
}),
|
||||||
|
remove: StaticCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
delete(w.Header(), k)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldQuery: {
|
||||||
|
help: Help{
|
||||||
|
command: FieldQuery,
|
||||||
|
args: map[string]string{
|
||||||
|
"key": "the query key",
|
||||||
|
"value": "the query value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: toStrTuple,
|
||||||
|
builder: func(args any) *FieldHandler {
|
||||||
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return &FieldHandler{
|
||||||
|
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateQueries(r, func(queries url.Values) {
|
||||||
|
queries.Set(k, v)
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateQueries(r, func(queries url.Values) {
|
||||||
|
queries.Add(k, v)
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateQueries(r, func(queries url.Values) {
|
||||||
|
queries.Del(k)
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FieldCookie: {
|
||||||
|
help: Help{
|
||||||
|
command: FieldCookie,
|
||||||
|
args: map[string]string{
|
||||||
|
"key": "the cookie key",
|
||||||
|
"value": "the cookie value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: toStrTuple,
|
||||||
|
builder: func(args any) *FieldHandler {
|
||||||
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return &FieldHandler{
|
||||||
|
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
|
for i, c := range cookies {
|
||||||
|
if c.Name == k {
|
||||||
|
cookies[i].Value = v
|
||||||
|
return cookies
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
|
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
|
index := -1
|
||||||
|
for i, c := range cookies {
|
||||||
|
if c.Name == k {
|
||||||
|
index = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if index != -1 {
|
||||||
|
if len(cookies) == 1 {
|
||||||
|
return []*http.Cookie{}
|
||||||
|
}
|
||||||
|
return append(cookies[:index], cookies[index+1:]...)
|
||||||
|
}
|
||||||
|
return cookies
|
||||||
|
})
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
|
@ -1,37 +1,34 @@
|
||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type RuleOn struct {
|
||||||
RuleOn struct {
|
raw string
|
||||||
raw string
|
checker Checker
|
||||||
check CheckFulfill
|
}
|
||||||
}
|
|
||||||
CheckFulfill func(r *http.Request) bool
|
|
||||||
Checkers []CheckFulfill
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OnHeader = "header"
|
OnHeader = "header"
|
||||||
OnQuery = "query"
|
OnQuery = "query"
|
||||||
OnCookie = "cookie"
|
OnCookie = "cookie"
|
||||||
OnForm = "form"
|
OnForm = "form"
|
||||||
OnPostForm = "postform"
|
OnPostForm = "postform"
|
||||||
OnMethod = "method"
|
OnMethod = "method"
|
||||||
OnPath = "path"
|
OnPath = "path"
|
||||||
OnRemote = "remote"
|
OnRemote = "remote"
|
||||||
|
OnBasicAuth = "basic_auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var checkers = map[string]struct {
|
var checkers = map[string]struct {
|
||||||
help Help
|
help Help
|
||||||
validate ValidateFunc
|
validate ValidateFunc
|
||||||
check func(r *http.Request, args any) bool
|
builder func(args any) CheckFunc
|
||||||
}{
|
}{
|
||||||
OnHeader: {
|
OnHeader: {
|
||||||
help: Help{
|
help: Help{
|
||||||
|
@ -42,8 +39,11 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toStrTuple,
|
validate: toStrTuple,
|
||||||
check: func(r *http.Request, args any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
return r.Header.Get(k) == v
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnQuery: {
|
OnQuery: {
|
||||||
|
@ -55,8 +55,17 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toStrTuple,
|
validate: toStrTuple,
|
||||||
check: func(r *http.Request, args any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
queries := cached.GetQueries(r)[k]
|
||||||
|
for _, query := range queries {
|
||||||
|
if query == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnCookie: {
|
OnCookie: {
|
||||||
|
@ -68,14 +77,18 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toStrTuple,
|
validate: toStrTuple,
|
||||||
check: func(r *http.Request, args any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
cookies := r.CookiesNamed(args.(StrTuple).First)
|
k, v := args.(*StrTuple).Unpack()
|
||||||
for _, cookie := range cookies {
|
return func(cached Cache, r *http.Request) bool {
|
||||||
if cookie.Value == args.(StrTuple).Second {
|
cookies := cached.GetCookies(r)
|
||||||
return true
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Name == k &&
|
||||||
|
cookie.Value == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnForm: {
|
OnForm: {
|
||||||
|
@ -87,8 +100,11 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toStrTuple,
|
validate: toStrTuple,
|
||||||
check: func(r *http.Request, args any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
return r.FormValue(args.(StrTuple).First) == args.(StrTuple).Second
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
return r.FormValue(k) == v
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnPostForm: {
|
OnPostForm: {
|
||||||
|
@ -100,8 +116,11 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toStrTuple,
|
validate: toStrTuple,
|
||||||
check: func(r *http.Request, args any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
return r.PostFormValue(args.(StrTuple).First) == args.(StrTuple).Second
|
k, v := args.(*StrTuple).Unpack()
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
return r.PostFormValue(k) == v
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnMethod: {
|
OnMethod: {
|
||||||
|
@ -112,8 +131,11 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateMethod,
|
validate: validateMethod,
|
||||||
check: func(r *http.Request, method any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
return r.Method == method.(string)
|
method := args.(string)
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
return r.Method == method
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnPath: {
|
OnPath: {
|
||||||
|
@ -127,12 +149,15 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateURLPath,
|
validate: validateURLPath,
|
||||||
check: func(r *http.Request, globPath any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
reqPath := r.URL.Path
|
pat := args.(string)
|
||||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
return func(cached Cache, r *http.Request) bool {
|
||||||
reqPath = "/" + reqPath
|
reqPath := r.URL.Path
|
||||||
|
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||||
|
reqPath = "/" + reqPath
|
||||||
|
}
|
||||||
|
return strutils.GlobMatch(pat, reqPath)
|
||||||
}
|
}
|
||||||
return strutils.GlobMatch(globPath.(string), reqPath)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnRemote: {
|
OnRemote: {
|
||||||
|
@ -143,16 +168,31 @@ var checkers = map[string]struct {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateCIDR,
|
validate: validateCIDR,
|
||||||
check: func(r *http.Request, cidr any) bool {
|
builder: func(args any) CheckFunc {
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
cidr := args.(types.CIDR)
|
||||||
if err != nil {
|
return func(cached Cache, r *http.Request) bool {
|
||||||
host = r.RemoteAddr
|
ip := cached.GetRemoteIP(r)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return cidr.Contains(ip)
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(host)
|
},
|
||||||
if ip == nil {
|
},
|
||||||
return false
|
OnBasicAuth: {
|
||||||
|
help: Help{
|
||||||
|
command: OnBasicAuth,
|
||||||
|
args: map[string]string{
|
||||||
|
"username": "the username",
|
||||||
|
"password": "the password encrypted with bcrypt",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
validate: validateUserBCryptPassword,
|
||||||
|
builder: func(args any) CheckFunc {
|
||||||
|
cred := args.(*HashedCrendentials)
|
||||||
|
return func(cached Cache, r *http.Request) bool {
|
||||||
|
return cred.Match(cached.GetBasicAuth(r))
|
||||||
}
|
}
|
||||||
return cidr.(*net.IPNet).Contains(ip)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -162,7 +202,7 @@ func (on *RuleOn) Parse(v string) error {
|
||||||
on.raw = v
|
on.raw = v
|
||||||
|
|
||||||
lines := strutils.SplitLine(v)
|
lines := strutils.SplitLine(v)
|
||||||
checks := make(Checkers, 0, len(lines))
|
checkAnd := make(CheckMatchAll, 0, len(lines))
|
||||||
|
|
||||||
errs := E.NewBuilder("rule.on syntax errors")
|
errs := E.NewBuilder("rule.on syntax errors")
|
||||||
for i, line := range lines {
|
for i, line := range lines {
|
||||||
|
@ -174,10 +214,10 @@ func (on *RuleOn) Parse(v string) error {
|
||||||
errs.Add(err.Subjectf("line %d", i+1))
|
errs.Add(err.Subjectf("line %d", i+1))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
checks = append(checks, parsed.matchOne())
|
checkAnd = append(checkAnd, parsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
on.check = checks.matchAll()
|
on.checker = checkAnd
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,28 +225,28 @@ func (on *RuleOn) String() string {
|
||||||
return on.raw
|
return on.raw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (on *RuleOn) MarshalJSON() ([]byte, error) {
|
func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||||
return []byte("\"" + on.String() + "\""), nil
|
return []byte(on.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOn(line string) (Checkers, E.Error) {
|
func parseOn(line string) (Checker, E.Error) {
|
||||||
ors := strutils.SplitRune(line, '|')
|
ors := strutils.SplitRune(line, '|')
|
||||||
|
|
||||||
if len(ors) > 1 {
|
if len(ors) > 1 {
|
||||||
errs := E.NewBuilder("rule.on syntax errors")
|
errs := E.NewBuilder("rule.on syntax errors")
|
||||||
checks := make([]CheckFulfill, len(ors))
|
checkOr := make(CheckMatchSingle, len(ors))
|
||||||
for i, or := range ors {
|
for i, or := range ors {
|
||||||
curCheckers, err := parseOn(or)
|
curCheckers, err := parseOn(or)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.Add(err)
|
errs.Add(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
checks[i] = curCheckers[0]
|
checkOr[i] = curCheckers.(CheckFunc)
|
||||||
}
|
}
|
||||||
if err := errs.Error(); err != nil {
|
if err := errs.Error(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return checks, nil
|
return checkOr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
subject, args, err := parse(line)
|
subject, args, err := parse(line)
|
||||||
|
@ -224,31 +264,5 @@ func parseOn(line string) (Checkers, E.Error) {
|
||||||
return nil, err.Subject(subject).Withf("%s", checker.help.String())
|
return nil, err.Subject(subject).Withf("%s", checker.help.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return Checkers{
|
return checker.builder(validArgs), nil
|
||||||
func(r *http.Request) bool {
|
|
||||||
return checker.check(r, validArgs)
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (checkers Checkers) matchOne() CheckFulfill {
|
|
||||||
return func(r *http.Request) bool {
|
|
||||||
for _, checker := range checkers {
|
|
||||||
if checker(r) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (checkers Checkers) matchAll() CheckFulfill {
|
|
||||||
return func(r *http.Request) bool {
|
|
||||||
for _, checker := range checkers {
|
|
||||||
if !checker(r) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
89
internal/route/rules/on_test.go
Normal file
89
internal/route/rules/on_test.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseOn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantErr E.Error
|
||||||
|
}{
|
||||||
|
// header
|
||||||
|
{
|
||||||
|
name: "header_valid",
|
||||||
|
input: "header Connection Upgrade",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "header_invalid",
|
||||||
|
input: "header Connection",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// query
|
||||||
|
{
|
||||||
|
name: "query_valid",
|
||||||
|
input: "query key value",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "query_invalid",
|
||||||
|
input: "query key",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// method
|
||||||
|
{
|
||||||
|
name: "method_valid",
|
||||||
|
input: "method GET",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "method_invalid",
|
||||||
|
input: "method",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// path
|
||||||
|
{
|
||||||
|
name: "path_valid",
|
||||||
|
input: "path /home",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path_invalid",
|
||||||
|
input: "path",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
// remote
|
||||||
|
{
|
||||||
|
name: "remote_valid",
|
||||||
|
input: "remote 127.0.0.1",
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remote_invalid",
|
||||||
|
input: "remote",
|
||||||
|
wantErr: ErrInvalidArguments,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown_target",
|
||||||
|
input: "unknown",
|
||||||
|
wantErr: ErrInvalidOnTarget,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
on := &RuleOn{}
|
||||||
|
err := on.Parse(tt.input)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
ExpectError(t, tt.wantErr, err)
|
||||||
|
} else {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ var escapedChars = map[rune]rune{
|
||||||
'\'': '\'',
|
'\'': '\'',
|
||||||
'"': '"',
|
'"': '"',
|
||||||
'\\': '\\',
|
'\\': '\\',
|
||||||
|
'$': '$',
|
||||||
' ': ' ',
|
' ': ' ',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -94,3 +94,13 @@ func TestParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkParser(b *testing.B) {
|
||||||
|
const input = `error 403 "Forbidden "foo" "bar""\ baz`
|
||||||
|
for range b.N {
|
||||||
|
_, _, err := parse(input)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,7 +27,7 @@ type (
|
||||||
on: method POST | method PUT
|
on: method POST | method PUT
|
||||||
do: error 403 Forbidden
|
do: error 403 Forbidden
|
||||||
*/
|
*/
|
||||||
Rules []Rule
|
Rules []*Rule
|
||||||
/*
|
/*
|
||||||
Rule is a rule for a reverse proxy.
|
Rule is a rule for a reverse proxy.
|
||||||
It do `Do` when `On` matches.
|
It do `Do` when `On` matches.
|
||||||
|
@ -52,52 +53,72 @@ type (
|
||||||
// if no rule matches, the default rule is executed
|
// if no rule matches, the default rule is executed
|
||||||
// if no rule matches and default rule is not set,
|
// if no rule matches and default rule is not set,
|
||||||
// the request is passed to the upstream.
|
// the request is passed to the upstream.
|
||||||
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
|
func (rules Rules) BuildHandler(caller string, up http.Handler) http.HandlerFunc {
|
||||||
var (
|
var defaultRule *Rule
|
||||||
defaultRule Rule
|
|
||||||
defaultRuleIndex int
|
|
||||||
)
|
|
||||||
|
|
||||||
|
nonDefaultRules := make(Rules, 0, len(rules))
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
if rule.Name == "default" {
|
if rule.Name == "default" {
|
||||||
defaultRule = rule
|
defaultRule = rule
|
||||||
defaultRuleIndex = i
|
nonDefaultRules = append(nonDefaultRules, rules[:i]...)
|
||||||
|
nonDefaultRules = append(nonDefaultRules, rules[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...)
|
|
||||||
|
|
||||||
// free allocated empty slices
|
|
||||||
// before encapsulating them into the handlerFunc.
|
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
if defaultRule.Do.isBypass() {
|
if defaultRule.Do.isBypass() {
|
||||||
return up.ServeHTTP
|
return up.ServeHTTP
|
||||||
}
|
}
|
||||||
rules = []Rule{}
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cache := NewCache()
|
||||||
|
defer cache.Release()
|
||||||
|
if defaultRule.Do.exec.Handle(cache, w, r) {
|
||||||
|
up.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
hasMatch := false
|
cache := NewCache()
|
||||||
for _, rule := range rules {
|
defer cache.Release()
|
||||||
if rule.On.check(r) {
|
|
||||||
|
for _, rule := range nonDefaultRules {
|
||||||
|
if rule.Check(cache, r) {
|
||||||
if rule.Do.isBypass() {
|
if rule.Do.isBypass() {
|
||||||
up.ServeHTTP(w, r)
|
up.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rule.Do.exec.HandlerFunc(w, r)
|
if !rule.Handle(cache, w, r) {
|
||||||
if !rule.Do.exec.proceed {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hasMatch = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasMatch || defaultRule.Do.isBypass() {
|
// bypass or proceed
|
||||||
|
if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) {
|
||||||
up.ServeHTTP(w, r)
|
up.ServeHTTP(w, r)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultRule.Do.exec.HandlerFunc(w, r)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rules Rules) MarshalJSON() ([]byte, error) {
|
||||||
|
names := make([]string, len(rules))
|
||||||
|
for i, rule := range rules {
|
||||||
|
names[i] = rule.Name
|
||||||
|
}
|
||||||
|
return json.Marshal(names)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *Rule) String() string {
|
||||||
|
return rule.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *Rule) Check(cached Cache, r *http.Request) bool {
|
||||||
|
return rule.On.checker.Check(cached, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
proceed = rule.Do.exec.Handle(cached, w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -3,249 +3,44 @@ package rules
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSubjectArgs(t *testing.T) {
|
|
||||||
t.Run("basic", func(t *testing.T) {
|
|
||||||
subject, args, err := parse("rewrite / /foo/bar")
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
ExpectEqual(t, subject, "rewrite")
|
|
||||||
ExpectDeepEqual(t, args, []string{"/", "/foo/bar"})
|
|
||||||
})
|
|
||||||
t.Run("with quotes", func(t *testing.T) {
|
|
||||||
subject, args, err := parse(`error 403 "Forbidden 'foo' 'bar'."`)
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
ExpectEqual(t, subject, "error")
|
|
||||||
ExpectDeepEqual(t, args, []string{"403", "Forbidden 'foo' 'bar'."})
|
|
||||||
})
|
|
||||||
t.Run("with escaped", func(t *testing.T) {
|
|
||||||
subject, args, err := parse(`error 403 Forbidden\ \"foo\"\ \"bar\".`)
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
ExpectEqual(t, subject, "error")
|
|
||||||
ExpectDeepEqual(t, args, []string{"403", "Forbidden \"foo\" \"bar\"."})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseCommands(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
// bypass tests
|
|
||||||
{
|
|
||||||
name: "bypass_valid",
|
|
||||||
input: "bypass",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bypass_invalid_with_args",
|
|
||||||
input: "bypass /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// rewrite tests
|
|
||||||
{
|
|
||||||
name: "rewrite_valid",
|
|
||||||
input: "rewrite / /foo/bar",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "rewrite_missing_target",
|
|
||||||
input: "rewrite /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "rewrite_too_many_args",
|
|
||||||
input: "rewrite / / /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "rewrite_no_leading_slash",
|
|
||||||
input: "rewrite abc /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// serve tests
|
|
||||||
{
|
|
||||||
name: "serve_valid",
|
|
||||||
input: "serve /var/www",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "serve_missing_path",
|
|
||||||
input: "serve ",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "serve_too_many_args",
|
|
||||||
input: "serve / / /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// redirect tests
|
|
||||||
{
|
|
||||||
name: "redirect_valid",
|
|
||||||
input: "redirect /",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "redirect_too_many_args",
|
|
||||||
input: "redirect / /",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// error directive tests
|
|
||||||
{
|
|
||||||
name: "error_valid",
|
|
||||||
input: "error 404 Not\\ Found",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error_missing_status_code",
|
|
||||||
input: "error Not\\ Found",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error_too_many_args",
|
|
||||||
input: "error 404 Not\\ Found extra",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error_no_escaped_space",
|
|
||||||
input: "error 404 Not Found",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error_invalid_status_code",
|
|
||||||
input: "error 123 abc",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// proxy directive tests
|
|
||||||
{
|
|
||||||
name: "proxy_valid",
|
|
||||||
input: "proxy localhost:8080",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy_missing_target",
|
|
||||||
input: "proxy",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy_too_many_args",
|
|
||||||
input: "proxy localhost:8080 extra",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "proxy_invalid_url",
|
|
||||||
input: "proxy :invalid_url",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// unknown directive test
|
|
||||||
{
|
|
||||||
name: "unknown_directive",
|
|
||||||
input: "unknown /",
|
|
||||||
wantErr: ErrUnknownDirective,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cmd := Command{}
|
|
||||||
err := cmd.Parse(tt.input)
|
|
||||||
if tt.wantErr != nil {
|
|
||||||
ExpectError(t, tt.wantErr, err)
|
|
||||||
} else {
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseOn(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
wantErr E.Error
|
|
||||||
}{
|
|
||||||
// header
|
|
||||||
{
|
|
||||||
name: "header_valid",
|
|
||||||
input: "header Connection Upgrade",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "header_invalid",
|
|
||||||
input: "header Connection",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// query
|
|
||||||
{
|
|
||||||
name: "query_valid",
|
|
||||||
input: "query key value",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "query_invalid",
|
|
||||||
input: "query key",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// method
|
|
||||||
{
|
|
||||||
name: "method_valid",
|
|
||||||
input: "method GET",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "method_invalid",
|
|
||||||
input: "method",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// path
|
|
||||||
{
|
|
||||||
name: "path_valid",
|
|
||||||
input: "path /home",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "path_invalid",
|
|
||||||
input: "path",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
// remote
|
|
||||||
{
|
|
||||||
name: "remote_valid",
|
|
||||||
input: "remote 127.0.0.1",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote_invalid",
|
|
||||||
input: "remote",
|
|
||||||
wantErr: ErrInvalidArguments,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unknown_target",
|
|
||||||
input: "unknown",
|
|
||||||
wantErr: ErrInvalidOnTarget,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
on := &RuleOn{}
|
|
||||||
err := on.Parse(tt.input)
|
|
||||||
if tt.wantErr != nil {
|
|
||||||
ExpectError(t, tt.wantErr, err)
|
|
||||||
} else {
|
|
||||||
ExpectNoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRule(t *testing.T) {
|
func TestParseRule(t *testing.T) {
|
||||||
// test := map[string]any{
|
test := []map[string]any{
|
||||||
// "name": "test",
|
{
|
||||||
// "on": "method GET",
|
"name": "test",
|
||||||
// "do": "bypass",
|
"on": "method POST",
|
||||||
// }
|
"do": "error 403 Forbidden",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "auth",
|
||||||
|
"on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`,
|
||||||
|
"do": "bypass",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "default",
|
||||||
|
"do": "require_basic_auth any_realm",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules struct {
|
||||||
|
Rules Rules
|
||||||
|
}
|
||||||
|
err := utils.Deserialize(utils.SerializedObject{"rules": test}, &rules)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, len(rules.Rules), len(test))
|
||||||
|
ExpectEqual(t, rules.Rules[0].Name, "test")
|
||||||
|
ExpectEqual(t, rules.Rules[0].On.String(), "method POST")
|
||||||
|
ExpectEqual(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
|
||||||
|
|
||||||
|
ExpectEqual(t, rules.Rules[1].Name, "auth")
|
||||||
|
ExpectEqual(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
|
||||||
|
ExpectEqual(t, rules.Rules[1].Do.String(), "bypass")
|
||||||
|
|
||||||
|
ExpectEqual(t, rules.Rules[2].Name, "default")
|
||||||
|
ExpectEqual(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: real tests.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -11,19 +12,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
ValidateFunc func(args []string) (any, E.Error)
|
ValidateFunc func(args []string) (any, E.Error)
|
||||||
StrTuple struct {
|
Tuple[T1, T2 any] struct {
|
||||||
First, Second string
|
First T1
|
||||||
|
Second T2
|
||||||
}
|
}
|
||||||
|
StrTuple = Tuple[string, string]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||||
|
return t.First, t.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tuple[T1, T2]) String() string {
|
||||||
|
return fmt.Sprintf("%v:%v", t.First, t.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// toStrTuple returns *StrTuple.
|
||||||
func toStrTuple(args []string) (any, E.Error) {
|
func toStrTuple(args []string) (any, E.Error) {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return nil, ErrExpectTwoArgs
|
return nil, ErrExpectTwoArgs
|
||||||
}
|
}
|
||||||
return StrTuple{args[0], args[1]}, nil
|
return &StrTuple{args[0], args[1]}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateURL returns types.URL with the URL validated.
|
||||||
func validateURL(args []string) (any, E.Error) {
|
func validateURL(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -35,6 +48,7 @@ func validateURL(args []string) (any, E.Error) {
|
||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateAbsoluteURL returns types.URL with the URL validated.
|
||||||
func validateAbsoluteURL(args []string) (any, E.Error) {
|
func validateAbsoluteURL(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -52,6 +66,7 @@ func validateAbsoluteURL(args []string) (any, E.Error) {
|
||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||||
func validateCIDR(args []string) (any, E.Error) {
|
func validateCIDR(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -66,6 +81,7 @@ func validateCIDR(args []string) (any, E.Error) {
|
||||||
return cidr, nil
|
return cidr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateURLPath returns string with the path validated.
|
||||||
func validateURLPath(args []string) (any, E.Error) {
|
func validateURLPath(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -86,6 +102,7 @@ func validateURLPath(args []string) (any, E.Error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateURLPaths returns []string with each element validated.
|
||||||
func validateURLPaths(paths []string) (any, E.Error) {
|
func validateURLPaths(paths []string) (any, E.Error) {
|
||||||
errs := E.NewBuilder("invalid url paths")
|
errs := E.NewBuilder("invalid url paths")
|
||||||
for i, p := range paths {
|
for i, p := range paths {
|
||||||
|
@ -102,6 +119,7 @@ func validateURLPaths(paths []string) (any, E.Error) {
|
||||||
return paths, nil
|
return paths, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateFSPath returns string with the path validated.
|
||||||
func validateFSPath(args []string) (any, E.Error) {
|
func validateFSPath(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -113,6 +131,7 @@ func validateFSPath(args []string) (any, E.Error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateMethod returns string with the method validated.
|
||||||
func validateMethod(args []string) (any, E.Error) {
|
func validateMethod(args []string) (any, E.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
|
@ -123,3 +142,31 @@ func validateMethod(args []string) (any, E.Error) {
|
||||||
}
|
}
|
||||||
return method, nil
|
return method, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||||
|
func validateUserBCryptPassword(args []string) (any, E.Error) {
|
||||||
|
if len(args) != 2 {
|
||||||
|
return nil, ErrExpectTwoArgs
|
||||||
|
}
|
||||||
|
return BCryptCrendentials(args[0], []byte(args[1])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateModField returns CommandHandler with the field validated.
|
||||||
|
func validateModField(mod FieldModifier, args []string) (CommandHandler, E.Error) {
|
||||||
|
setField, ok := modFields[args[0]]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidSetTarget.Subject(args[0])
|
||||||
|
}
|
||||||
|
validArgs, err := setField.validate(args[1:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err.Withf(setField.help.String())
|
||||||
|
}
|
||||||
|
modder := setField.builder(validArgs)
|
||||||
|
switch mod {
|
||||||
|
case ModFieldAdd:
|
||||||
|
return modder.add, nil
|
||||||
|
case ModFieldRemove:
|
||||||
|
return modder.remove, nil
|
||||||
|
}
|
||||||
|
return modder.set, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue