updated implementation of rules

This commit is contained in:
yusing 2025-01-08 13:50:10 +08:00
parent bc1702e6cf
commit a98b2bb71a
12 changed files with 775 additions and 371 deletions

View file

@ -0,0 +1,20 @@
package http
import "net/http"
var validMethods = map[string]struct{}{
http.MethodGet: {},
http.MethodHead: {},
http.MethodPost: {},
http.MethodPut: {},
http.MethodPatch: {},
http.MethodDelete: {},
http.MethodConnect: {},
http.MethodOptions: {},
http.MethodTrace: {},
}
func IsMethodValid(method string) bool {
_, ok := validMethods[method]
return ok
}

View file

@ -5,3 +5,7 @@ import "net/http"
func IsSuccess(status int) bool { func IsSuccess(status int) bool {
return status >= http.StatusOK && status < http.StatusMultipleChoices return status >= http.StatusOK && status < http.StatusMultipleChoices
} }
func IsStatusCodeValid(status int) bool {
return http.StatusText(status) != ""
}

View file

@ -8,6 +8,11 @@ import (
//nolint:recvcheck //nolint:recvcheck
type CIDR net.IPNet type CIDR net.IPNet
func ParseCIDR(v string) (cidr CIDR, err error) {
err = cidr.Parse(v)
return
}
func (cidr *CIDR) Parse(v string) error { func (cidr *CIDR) Parse(v string) error {
if !strings.Contains(v, "/") { if !strings.Contains(v, "/") {
v += "/32" // single IP v += "/32" // single IP

164
internal/route/rules/do.go Normal file
View file

@ -0,0 +1,164 @@
package rules
import (
"net/http"
"path"
"strconv"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
Command struct {
raw string
CommandExecutor
}
CommandExecutor struct {
http.HandlerFunc
proceed bool
}
)
const (
CommandRewrite = "rewrite"
CommandServe = "serve"
CommandProxy = "proxy"
CommandRedirect = "redirect"
CommandError = "error"
CommandBypass = "bypass"
)
var commands = map[string]struct {
validate ValidateFunc
build func(args any) CommandExecutor
}{
CommandRewrite: {
validate: func(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return validateURLPaths(args)
},
build: func(args any) CommandExecutor {
a := args.([]string)
orig, repl := a[0], a[1]
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
if len(r.URL.Path) > 0 && r.URL.Path[0] != '/' {
r.URL.Path = "/" + r.URL.Path
}
r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1)
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.String()
},
proceed: true,
}
},
},
CommandServe: {
validate: validateFSPath,
build: func(args any) CommandExecutor {
root := args.(string)
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
},
proceed: false,
}
},
},
CommandRedirect: {
validate: validateURL,
build: func(args any) CommandExecutor {
target := args.(types.URL).String()
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
},
proceed: false,
}
},
},
CommandError: {
validate: func(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
codeStr, text := args[0], args[1]
code, err := strconv.Atoi(codeStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if !gphttp.IsStatusCodeValid(code) {
return nil, ErrInvalidArguments.Subject(codeStr)
}
return []any{code, text}, nil
},
build: func(args any) CommandExecutor {
a := args.([]any)
code, text := a[0].(int), a[1].(string)
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, text, code)
},
proceed: false,
}
},
},
CommandProxy: {
validate: validateURL,
build: func(args any) CommandExecutor {
target := args.(types.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport)
return CommandExecutor{
HandlerFunc: rp.ServeHTTP,
proceed: false,
}
},
},
}
func (cmd *Command) Parse(v string) error {
cmd.raw = v
directive, args, err := parse(v)
if err != nil {
return err
}
if directive == CommandBypass {
if len(args) != 0 {
return ErrInvalidArguments.Subject(directive)
}
return nil
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
validArgs, err := builder.validate(args)
if err != nil {
return err.Subject(directive)
}
cmd.CommandExecutor = builder.build(validArgs)
return nil
}
func (cmd *Command) isBypass() bool {
return cmd.HandlerFunc == nil
}
func (cmd *Command) String() string {
return cmd.raw
}
func (cmd *Command) MarshalJSON() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil
}

View file

@ -0,0 +1,14 @@
package rules
import E "github.com/yusing/go-proxy/internal/error"
var (
ErrUnterminatedQuotes = E.New("unterminated quotes")
ErrUnsupportedEscapeChar = E.New("unsupported escape char")
ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidOnTarget = E.New("invalid `rule.on` target")
ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg")
ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args")
)

166
internal/route/rules/on.go Normal file
View file

@ -0,0 +1,166 @@
package rules
import (
"net"
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
RuleOn struct {
raw string
check CheckFulfill
}
CheckFulfill func(r *http.Request) bool
Checkers []CheckFulfill
)
const (
OnHeader = "header"
OnQuery = "query"
OnMethod = "method"
OnPath = "path"
OnRemote = "remote"
)
var checkers = map[string]struct {
validate ValidateFunc
check func(r *http.Request, args any) bool
}{
OnHeader: { // header <key> <value>
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnQuery: { // query <key> <value>
validate: toStrTuple,
check: func(r *http.Request, args any) bool {
return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second
},
},
OnMethod: { // method <method>
validate: validateMethod,
check: func(r *http.Request, method any) bool {
return r.Method == method.(string)
},
},
OnPath: { // path <path>
validate: validateURLPath,
check: func(r *http.Request, globPath any) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
}
return strutils.GlobMatch(globPath.(string), reqPath)
},
},
OnRemote: { // remote <ip|cidr>
validate: validateCIDR,
check: func(r *http.Request, cidr any) bool {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return cidr.(*net.IPNet).Contains(ip)
},
},
}
func (on *RuleOn) Parse(v string) error {
on.raw = v
lines := strutils.SplitLine(v)
checks := make(Checkers, 0, len(lines))
errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines {
parsed, err := parseOn(line)
if err != nil {
errs.Add(err.Subjectf("line %d", i+1))
continue
}
checks = append(checks, parsed.matchOne())
}
on.check = checks.matchAll()
return errs.Error()
}
func (on *RuleOn) String() string {
return on.raw
}
func (on *RuleOn) MarshalJSON() ([]byte, error) {
return []byte("\"" + on.String() + "\""), nil
}
func parseOn(line string) (Checkers, E.Error) {
ors := strutils.SplitRune(line, '|')
if len(ors) > 1 {
errs := E.NewBuilder("rule.on syntax errors")
checks := make([]CheckFulfill, len(ors))
for i, or := range ors {
curCheckers, err := parseOn(or)
if err != nil {
errs.Add(err)
continue
}
checks[i] = curCheckers[0]
}
if err := errs.Error(); err != nil {
return nil, err
}
return checks, nil
}
subject, args, err := parse(line)
if err != nil {
return nil, err
}
checker, ok := checkers[subject]
if !ok {
return nil, ErrInvalidOnTarget.Subject(subject)
}
validArgs, err := checker.validate(args)
if err != nil {
return nil, err.Subject(subject)
}
return Checkers{
func(r *http.Request) bool {
return checker.check(r, validArgs)
},
}, nil
}
func (checkers Checkers) matchOne() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if checker(r) {
return true
}
}
return false
}
}
func (checkers Checkers) matchAll() CheckFulfill {
return func(r *http.Request) bool {
for _, checker := range checkers {
if !checker(r) {
return false
}
}
return true
}
}

View file

@ -0,0 +1,78 @@
package rules
import (
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
var escapedChars = map[rune]rune{
'n': '\n',
't': '\t',
'r': '\r',
'\'': '\'',
'"': '"',
' ': ' ',
}
// parse expression to subject and args
// with support for quotes and escaped chars, e.g.
//
// error 403 "Forbidden 'foo' 'bar'"
// error 403 Forbidden\ \"foo\"\ \"bar\".
func parse(v string) (subject string, args []string, err E.Error) {
v = strings.TrimSpace(v)
var buf strings.Builder
escaped := false
quotes := make([]rune, 0, 4)
flush := func() {
if subject == "" {
subject = buf.String()
} else {
args = append(args, buf.String())
}
buf.Reset()
}
for _, r := range v {
if escaped {
if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch)
} else {
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r)
return
}
escaped = false
continue
}
switch r {
case '\\':
escaped = true
continue
case '"', '\'':
switch {
case len(quotes) > 0 && quotes[len(quotes)-1] == r:
quotes = quotes[:len(quotes)-1]
if len(quotes) == 0 {
flush()
} else {
buf.WriteRune(r)
}
case len(quotes) == 0:
quotes = append(quotes, r)
default:
buf.WriteRune(r)
}
case ' ':
flush()
default:
buf.WriteRune(r)
}
}
if len(quotes) > 0 {
err = ErrUnterminatedQuotes
} else {
flush()
}
return
}

View file

@ -0,0 +1,110 @@
package rules
import (
"net/http"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
)
type (
/*
Example:
proxy.app1.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection Upgrade
header Upgrade websocket
do: bypass
proxy.app2.rules: |
- name: default
do: bypass
- name: block POST and PUT
on: method POST | method PUT
do: error 403 Forbidden
*/
Rules []Rule
/*
Rule is a rule for a reverse proxy.
It do `Do` when `On` matches.
A rule can have multiple lines of on.
All lines of on must match,
but each line can have multiple checks that
one match means this line is matched.
*/
Rule struct {
Name string `json:"name" validate:"required,unique"`
On RuleOn `json:"on"`
Do Command `json:"do"`
}
)
// BuildHandler returns a http.HandlerFunc that implements the rules.
//
// Bypass rules are executed first
// if a bypass rule matches,
// the request is passed to the upstream and no more rules are executed.
//
// Other rules are executed later
// if no rule matches, the default rule is executed
// if no rule matches and default rule is not set,
// the request is passed to the upstream.
func (rules Rules) BuildHandler(up *reverseproxy.ReverseProxy) http.HandlerFunc {
// move bypass rules to the front.
bypassRules := make(Rules, 0, len(rules))
otherRules := make(Rules, 0, len(rules))
var defaultRule Rule
for _, rule := range rules {
switch {
case rule.Do.isBypass():
bypassRules = append(bypassRules, rule)
case rule.Name == "default":
defaultRule = rule
default:
otherRules = append(otherRules, rule)
}
}
// free allocated empty slices
// before encapsulating them into the handlerFunc.
if len(bypassRules) == 0 {
bypassRules = []Rule{}
}
if len(otherRules) == 0 {
otherRules = []Rule{defaultRule}
}
return func(w http.ResponseWriter, r *http.Request) {
for _, rule := range bypassRules {
if rule.On.check(r) {
up.ServeHTTP(w, r)
return
}
}
hasMatch := false
for _, rule := range otherRules {
if rule.On.check(r) {
hasMatch = true
rule.Do.HandlerFunc(w, r)
if !rule.Do.proceed {
return
}
}
}
if hasMatch || defaultRule.Do.isBypass() {
up.ServeHTTP(w, r)
return
}
defaultRule.Do.HandlerFunc(w, r)
}
}

View file

@ -1,13 +1,14 @@
package types package rules
import ( import (
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestParseSubjectArgs(t *testing.T) { func TestParseSubjectArgs(t *testing.T) {
t.Run("without quotes", func(t *testing.T) { t.Run("basic", func(t *testing.T) {
subject, args, err := parse("rewrite / /foo/bar") subject, args, err := parse("rewrite / /foo/bar")
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, subject, "rewrite") ExpectEqual(t, subject, "rewrite")
@ -60,6 +61,11 @@ func TestParseCommands(t *testing.T) {
input: "rewrite / / /", input: "rewrite / / /",
wantErr: ErrInvalidArguments, wantErr: ErrInvalidArguments,
}, },
{
name: "rewrite_no_leading_slash",
input: "rewrite abc /",
wantErr: ErrInvalidArguments,
},
// serve tests // serve tests
{ {
name: "serve_valid", name: "serve_valid",
@ -104,10 +110,15 @@ func TestParseCommands(t *testing.T) {
wantErr: ErrInvalidArguments, wantErr: ErrInvalidArguments,
}, },
{ {
name: "error_unescaped_space", name: "error_no_escaped_space",
input: "error 404 Not Found", input: "error 404 Not Found",
wantErr: ErrInvalidArguments, wantErr: ErrInvalidArguments,
}, },
{
name: "error_invalid_status_code",
input: "error 123 abc",
wantErr: ErrInvalidArguments,
},
// proxy directive tests // proxy directive tests
{ {
name: "proxy_valid", name: "proxy_valid",
@ -124,6 +135,11 @@ func TestParseCommands(t *testing.T) {
input: "proxy localhost:8080 extra", input: "proxy localhost:8080 extra",
wantErr: ErrInvalidArguments, wantErr: ErrInvalidArguments,
}, },
{
name: "proxy_invalid_url",
input: "proxy :invalid_url",
wantErr: ErrInvalidArguments,
},
// unknown directive test // unknown directive test
{ {
name: "unknown_directive", name: "unknown_directive",
@ -144,3 +160,92 @@ func TestParseCommands(t *testing.T) {
}) })
} }
} }
func TestParseOn(t *testing.T) {
tests := []struct {
name string
input string
wantErr E.Error
}{
// header
{
name: "header_valid",
input: "header Connection Upgrade",
wantErr: nil,
},
{
name: "header_invalid",
input: "header Connection",
wantErr: ErrInvalidArguments,
},
// query
{
name: "query_valid",
input: "query key value",
wantErr: nil,
},
{
name: "query_invalid",
input: "query key",
wantErr: ErrInvalidArguments,
},
// method
{
name: "method_valid",
input: "method GET",
wantErr: nil,
},
{
name: "method_invalid",
input: "method",
wantErr: ErrInvalidArguments,
},
// path
{
name: "path_valid",
input: "path /home",
wantErr: nil,
},
{
name: "path_invalid",
input: "path",
wantErr: ErrInvalidArguments,
},
// remote
{
name: "remote_valid",
input: "remote 127.0.0.1",
wantErr: nil,
},
{
name: "remote_invalid",
input: "remote",
wantErr: ErrInvalidArguments,
},
{
name: "unknown_target",
input: "unknown",
wantErr: ErrInvalidOnTarget,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
on := &RuleOn{}
err := on.Parse(tt.input)
if tt.wantErr != nil {
ExpectError(t, tt.wantErr, err)
} else {
ExpectNoError(t, err)
}
})
}
}
func TestParseRule(t *testing.T) {
// test := map[string]any{
// "name": "test",
// "on": "method GET",
// "do": "bypass",
// }
}

View file

@ -0,0 +1,104 @@
package rules
import (
"os"
"path"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
ValidateFunc func(args []string) (any, E.Error)
StrTuple struct {
First, Second string
}
)
func toStrTuple(args []string) (any, E.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return StrTuple{args[0], args[1]}, nil
}
func validateURL(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
u, err := types.ParseURL(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return u, nil
}
func validateCIDR(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
if !strings.Contains(args[0], "/") {
args[0] += "/32"
}
cidr, err := types.ParseCIDR(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return cidr, nil
}
func validateURLPath(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
p := args[0]
p, _, _ = strings.Cut(p, "#")
p = path.Clean(p)
if len(p) == 0 {
return "/", nil
}
if p[0] != '/' {
return nil, ErrInvalidArguments.Withf("must start with /")
}
return p, nil
}
func validateURLPaths(paths []string) (any, E.Error) {
errs := E.NewBuilder("invalid url paths")
for i, p := range paths {
val, err := validateURLPath([]string{p})
if err != nil {
errs.Add(err.Subject(p))
continue
}
paths[i] = val.(string)
}
if err := errs.Error(); err != nil {
return nil, err
}
return paths, nil
}
func validateFSPath(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
p := path.Clean(args[0])
if _, err := os.Stat(p); err != nil {
return nil, ErrInvalidArguments.With(err)
}
return p, nil
}
func validateMethod(args []string) (any, E.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
method := strings.ToUpper(args[0])
if !gphttp.IsMethodValid(method) {
return nil, ErrInvalidArguments.Subject(method)
}
return method, nil
}

View file

@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/rules"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
@ -30,7 +31,7 @@ type (
Port string `json:"port,omitempty"` Port string `json:"port,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns []string `json:"path_patterns,omitempty"` PathPatterns []string `json:"path_patterns,omitempty"`
Rules Rules `json:"rules,omitempty"` Rules rules.Rules `json:"rules,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` LoadBalance *loadbalance.Config `json:"load_balance,omitempty"`
Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"`

View file

@ -1,367 +0,0 @@
package types
import (
"net/http"
"path"
"strconv"
"strings"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type (
Rules []Rule
Rule struct {
Name string `json:"name" validate:"required,unique"`
On RuleOn `json:"on"`
Do Command `json:"do"`
}
RuleOn struct {
raw string
checkers []CheckFulfill
}
Command struct {
raw string
CommandExecutor
}
CheckFulfill func(r *http.Request) bool
RequestObjectRetriever struct {
expectedArgs int
retrieve func(r *http.Request, args []string) string
equal func(v, want string) bool
}
CommandExecutor struct {
http.HandlerFunc
proceed bool
}
CommandBuilder struct {
expectedArgs int
build func(args []string) CommandExecutor
}
)
/*
proxy.app1.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection upgrade
header Upgrade websocket
do: proxy $upstream_url
*/
var (
ErrUnterminatedQuotes = E.New("unterminated quotes")
ErrUnsupportedEscapeChar = E.New("unsupported escape char")
ErrUnknownDirective = E.New("unknown directive")
ErrInvalidArguments = E.New("invalid arguments")
ErrInvalidCriteria = E.New("invalid criteria")
ErrInvalidCriteriaTarget = E.New("invalid criteria target")
)
var retrievers = map[string]RequestObjectRetriever{
"header": {1, func(r *http.Request, args []string) string {
return r.Header.Get(args[0])
}, nil},
"query": {1, func(r *http.Request, args []string) string {
return r.URL.Query().Get(args[0])
}, nil},
"method": {0, func(r *http.Request, _ []string) string {
return r.Method
}, nil},
"path": {0, func(r *http.Request, _ []string) string {
return r.URL.Path
}, func(v, want string) bool {
return strutils.GlobMatch(want, v)
}},
"remote": {0, func(r *http.Request, _ []string) string {
return r.RemoteAddr
}, nil},
}
var commands = map[string]CommandBuilder{
"rewrite": {2, func(args []string) CommandExecutor {
orig, repl := args[0], args[1]
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.Replace(r.URL.Path, orig, repl, 1)
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.String()
},
proceed: true,
}
}},
"serve": {1, func(args []string) CommandExecutor {
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path.Join(args[0], path.Clean(r.URL.Path)))
},
proceed: false,
}
}},
"redirect": {1, func(args []string) CommandExecutor {
target := args[0]
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
},
proceed: false,
}
}},
"error": {2, func(args []string) CommandExecutor {
codeStr, text := args[0], args[1]
code, err := strconv.Atoi(codeStr)
if err != nil {
code = http.StatusNotFound
}
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
http.Error(w, text, code)
},
proceed: false,
}
}},
"proxy": {1, func(args []string) CommandExecutor {
target := args[0]
return CommandExecutor{
HandlerFunc: func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "http"
r.URL.Host = target
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.String()
},
proceed: true,
}
}},
}
var escapedChars = map[rune]rune{
'n': '\n',
't': '\t',
'r': '\r',
'\'': '\'',
'"': '"',
' ': ' ',
}
// BuildHandler returns a http.HandlerFunc that implements the rules.
//
// Bypass rules are executed first
// if a bypass rule matches,
// the request is passed to the upstream and no more rules are executed.
//
// Other rules are executed later
// if no rule matches, the default rule is executed
// if no rule matches and default rule is not set,
// the request is passed to the upstream.
func (rules Rules) BuildHandler(up *gphttp.ReverseProxy) http.HandlerFunc {
// move bypass rules to the front.
bypassRules := make(Rules, 0, len(rules))
otherRules := make(Rules, 0, len(rules))
var defaultRule Rule
for _, rule := range rules {
switch {
case rule.Do.isBypass():
bypassRules = append(bypassRules, rule)
case rule.Name == "default":
defaultRule = rule
default:
otherRules = append(otherRules, rule)
}
}
// free allocated empty slices
// before passing them to the handler.
if len(bypassRules) == 0 {
bypassRules = []Rule{}
}
if len(otherRules) == 0 {
otherRules = []Rule{defaultRule}
}
return func(w http.ResponseWriter, r *http.Request) {
hasMatch := false
for _, rule := range bypassRules {
if rule.On.MatchAll(r) {
up.ServeHTTP(w, r)
return
}
}
for _, rule := range otherRules {
if rule.On.MatchAll(r) {
hasMatch = true
rule.Do.HandlerFunc(w, r)
if !rule.Do.proceed {
return
}
}
}
if hasMatch || defaultRule.Do.isBypass() {
up.ServeHTTP(w, r)
return
}
defaultRule.Do.HandlerFunc(w, r)
if !defaultRule.Do.proceed {
return
}
}
}
// parse line to subject and args
// with support for quotes and escaped chars, e.g.
//
// error 403 "Forbidden 'foo' 'bar'"
// error 403 Forbidden\ \"foo\"\ \"bar\".
func parse(v string) (subject string, args []string, err E.Error) {
v = strings.TrimSpace(v)
var buf strings.Builder
escaped := false
quotes := make([]rune, 0, 4)
flush := func() {
if subject == "" {
subject = buf.String()
} else {
args = append(args, buf.String())
}
buf.Reset()
}
for _, r := range v {
if escaped {
if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch)
} else {
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r)
return
}
escaped = false
continue
}
switch r {
case '\\':
escaped = true
continue
case '"', '\'':
switch {
case len(quotes) > 0 && quotes[len(quotes)-1] == r:
quotes = quotes[:len(quotes)-1]
if len(quotes) == 0 {
flush()
} else {
buf.WriteRune(r)
}
case len(quotes) == 0:
quotes = append(quotes, r)
default:
buf.WriteRune(r)
}
case ' ':
flush()
default:
buf.WriteRune(r)
}
}
if len(quotes) > 0 {
err = ErrUnterminatedQuotes
} else {
flush()
}
return
}
func (on *RuleOn) Parse(v string) E.Error {
lines := strutils.SplitLine(v)
on.checkers = make([]CheckFulfill, 0, len(lines))
on.raw = v
errs := E.NewBuilder("rule.on syntax errors")
for i, line := range lines {
subject, args, err := parse(line)
if err != nil {
errs.Add(err.Subjectf("line %d", i+1))
continue
}
retriever, ok := retrievers[subject]
if !ok {
errs.Add(ErrInvalidCriteriaTarget.Subject(subject).Subjectf("line %d", i+1))
continue
}
nArgs := retriever.expectedArgs
if len(args) != nArgs+1 {
errs.Add(ErrInvalidArguments.Subject(subject).Subjectf("line %d", i+1))
continue
}
equal := retriever.equal
if equal == nil {
equal = func(a, b string) bool {
return a == b
}
}
on.checkers = append(on.checkers, func(r *http.Request) bool {
return equal(retriever.retrieve(r, args[:nArgs]), args[nArgs])
})
}
return errs.Error()
}
func (on *RuleOn) MatchAll(r *http.Request) bool {
for _, match := range on.checkers {
if !match(r) {
return false
}
}
return true
}
func (cmd *Command) Parse(v string) E.Error {
cmd.raw = v
directive, args, err := parse(v)
if err != nil {
return err
}
if directive == "bypass" {
if len(args) != 0 {
return ErrInvalidArguments.Subject(directive)
}
return nil
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
if len(args) != builder.expectedArgs {
return ErrInvalidArguments.Subject(directive)
}
cmd.CommandExecutor = builder.build(args)
return nil
}
func (cmd *Command) isBypass() bool {
return cmd.HandlerFunc == nil
}
func (on *RuleOn) String() string {
return on.raw
}
func (on *RuleOn) MarshalJSON() ([]byte, error) {
return []byte("\"" + on.String() + "\""), nil
}
func (cmd *Command) String() string {
return cmd.raw
}
func (cmd *Command) MarshalJSON() ([]byte, error) {
return []byte("\"" + cmd.String() + "\""), nil
}