From f2df756c175d7a568d6650a54062db595b104533 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 11 Jan 2025 02:14:22 +0800 Subject: [PATCH] fix rule parser --- internal/route/rules/parser.go | 55 +++++++++++------ internal/route/rules/parser_test.go | 96 +++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 19 deletions(-) create mode 100644 internal/route/rules/parser_test.go diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go index ac51ebf..653c150 100644 --- a/internal/route/rules/parser.go +++ b/internal/route/rules/parser.go @@ -1,7 +1,8 @@ package rules import ( - "strings" + "bytes" + "unicode" E "github.com/yusing/go-proxy/internal/error" ) @@ -22,15 +23,30 @@ var escapedChars = map[rune]rune{ // error 403 "Forbidden 'foo' 'bar'" // error 403 Forbidden\ \"foo\"\ \"bar\". func parse(v string) (subject string, args []string, err E.Error) { - v = strings.TrimSpace(v) - var buf strings.Builder + buf := bytes.NewBuffer(make([]byte, 0, len(v))) + escaped := false - quotes := make([]rune, 0, 4) - flush := func() { + quote := rune(0) + flush := func(quoted bool) { + part := buf.String() + if !quoted { + beg := 0 + for i, r := range part { + if unicode.IsSpace(r) { + beg = i + 1 + } else { + break + } + } + if beg == len(part) { // all spaces + return + } + part = part[beg:] // trim leading spaces + } if subject == "" { - subject = buf.String() + subject = part } else { - args = append(args, buf.String()) + args = append(args, part) } buf.Reset() } @@ -51,29 +67,30 @@ func parse(v string) (subject string, args []string, err E.Error) { continue case '"', '\'': switch { - case len(quotes) > 0 && quotes[len(quotes)-1] == r: - quotes = quotes[:len(quotes)-1] - if len(quotes) == 0 { - flush() - } else { - buf.WriteRune(r) - } - case len(quotes) == 0: - quotes = append(quotes, r) + case quote == 0: + quote = r + flush(false) + case r == quote: + quote = 0 + flush(true) default: buf.WriteRune(r) } case ' ': - flush() + if quote == 0 { + flush(false) + continue + } + fallthrough default: buf.WriteRune(r) } } - if len(quotes) > 0 { + if quote != 0 { err = ErrUnterminatedQuotes } else { - flush() + flush(false) } return } diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go new file mode 100644 index 0000000..b560ef4 --- /dev/null +++ b/internal/route/rules/parser_test.go @@ -0,0 +1,96 @@ +package rules + +import ( + "strconv" + "testing" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestParser(t *testing.T) { + tests := []struct { + name string + input string + subject string + args []string + wantErr E.Error + }{ + { + name: "basic", + input: "rewrite / /foo/bar", + subject: "rewrite", + args: []string{"/", "/foo/bar"}, + }, + { + name: "with quotes", + input: `error 403 "Forbidden 'foo' 'bar'."`, + subject: "error", + args: []string{"403", "Forbidden 'foo' 'bar'."}, + }, + { + name: "with quotes 2", + input: `basic_auth "username" "password"`, + subject: "basic_auth", + args: []string{"username", "password"}, + }, + { + name: "with escaped", + input: `foo bar\ baz bar\r\n\tbaz bar\'\"baz`, + subject: "foo", + args: []string{"bar baz", "bar\r\n\tbaz", `bar'"baz`}, + }, + { + name: "empty string", + input: `foo '' ""`, + subject: "foo", + args: []string{"", ""}, + }, + { + name: "invalid_escape", + input: `foo \bar`, + wantErr: ErrUnsupportedEscapeChar, + }, + { + name: "chaos", + input: `error 403 "Forbidden "foo" "bar""`, + subject: "error", + args: []string{"403", "Forbidden ", "foo", " ", "bar", ""}, + }, + { + name: "chaos2", + input: `foo "'bar' 'baz'" abc\ 'foo "bar"'.`, + subject: "foo", + args: []string{"'bar' 'baz'", "abc ", `foo "bar"`, "."}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subject, args, err := parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + return + } + // t.Log(subject, args, err) + ExpectNoError(t, err) + ExpectEqual(t, subject, tt.subject) + ExpectEqual(t, len(args), len(tt.args)) + for i, arg := range args { + ExpectEqual(t, arg, tt.args[i]) + } + }) + } + t.Run("unterminated quotes", func(t *testing.T) { + tests := []string{ + `error 403 "Forbidden 'foo' 'bar'`, + `error 403 "Forbidden 'foo 'bar'`, + `error 403 "Forbidden foo "bar'"`, + } + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + _, _, err := parse(test) + ExpectError(t, ErrUnterminatedQuotes, err) + }) + } + }) +}