diff --git a/internal/net/http/methods.go b/internal/net/http/methods.go index a46923d..caca564 100644 --- a/internal/net/http/methods.go +++ b/internal/net/http/methods.go @@ -2,19 +2,19 @@ package http import "net/http" -var validMethods = map[string]struct{}{ - http.MethodGet: {}, - http.MethodHead: {}, - http.MethodPost: {}, - http.MethodPut: {}, - http.MethodPatch: {}, - http.MethodDelete: {}, - http.MethodConnect: {}, - http.MethodOptions: {}, - http.MethodTrace: {}, -} - func IsMethodValid(method string) bool { - _, ok := validMethods[method] - return ok + switch method { + case http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodConnect, + http.MethodOptions, + http.MethodTrace: + return true + default: + return false + } } diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index 6a9e852..4bd3151 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -11,7 +11,8 @@ var ( ErrInvalidCommandSequence = E.New("invalid command sequence") ErrInvalidSetTarget = E.New("invalid `rule.set` target") - ErrExpectNoArg = ErrInvalidArguments.Withf("expect no arg") - ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") - ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") + ErrExpectNoArg = E.New("expect no arg") + ErrExpectOneArg = E.New("expect 1 arg") + ErrExpectTwoArgs = E.New("expect 2 args") + ErrExpectKVOptionalV = E.New("expect 'key' or 'key value'") ) diff --git a/internal/route/rules/help.go b/internal/route/rules/help.go index cdff5c9..9b222b7 100644 --- a/internal/route/rules/help.go +++ b/internal/route/rules/help.go @@ -20,9 +20,8 @@ func (h *Help) String() string { sb.WriteString(h.command) sb.WriteString(" ") for arg := range h.args { - sb.WriteRune('<') - sb.WriteString(arg) - sb.WriteString("> ") + sb.WriteString(strings.ToUpper(arg)) + sb.WriteRune(' ') } if h.description != "" { sb.WriteString("\n\t") @@ -32,7 +31,7 @@ func (h *Help) String() string { sb.WriteRune('\n') for arg, desc := range h.args { sb.WriteRune('\t') - sb.WriteString(arg) + sb.WriteString(strings.ToUpper(arg)) sb.WriteString(": ") sb.WriteString(desc) sb.WriteRune('\n') diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index f2c64c8..e0b1837 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -34,15 +34,25 @@ var checkers = map[string]struct { help: Help{ command: OnHeader, args: map[string]string{ - "key": "the header key", - "value": "the header value", + "key": "the header key", + "[value]": "the header value", }, }, - validate: toStrTuple, + validate: toKVOptionalV, builder: func(args any) CheckFunc { k, v := args.(*StrTuple).Unpack() + if v == "" { + return func(cached Cache, r *http.Request) bool { + return len(r.Header[k]) > 0 + } + } return func(cached Cache, r *http.Request) bool { - return r.Header.Get(k) == v + for _, vv := range r.Header[k] { + if v == vv { + return true + } + } + return false } }, }, @@ -50,13 +60,18 @@ var checkers = map[string]struct { help: Help{ command: OnQuery, args: map[string]string{ - "key": "the query key", - "value": "the query value", + "key": "the query key", + "[value]": "the query value", }, }, - validate: toStrTuple, + validate: toKVOptionalV, builder: func(args any) CheckFunc { k, v := args.(*StrTuple).Unpack() + if v == "" { + return func(cached Cache, r *http.Request) bool { + return len(cached.GetQueries(r)[k]) > 0 + } + } return func(cached Cache, r *http.Request) bool { queries := cached.GetQueries(r)[k] for _, query := range queries { @@ -72,13 +87,24 @@ var checkers = map[string]struct { help: Help{ command: OnCookie, args: map[string]string{ - "key": "the cookie key", - "value": "the cookie value", + "key": "the cookie key", + "[value]": "the cookie value", }, }, - validate: toStrTuple, + validate: toKVOptionalV, builder: func(args any) CheckFunc { k, v := args.(*StrTuple).Unpack() + if v == "" { + return func(cached Cache, r *http.Request) bool { + cookies := cached.GetCookies(r) + for _, cookie := range cookies { + if cookie.Name == k { + return true + } + } + return false + } + } return func(cached Cache, r *http.Request) bool { cookies := cached.GetCookies(r) for _, cookie := range cookies { @@ -95,13 +121,18 @@ var checkers = map[string]struct { help: Help{ command: OnForm, args: map[string]string{ - "key": "the form key", - "value": "the form value", + "key": "the form key", + "[value]": "the form value", }, }, - validate: toStrTuple, + validate: toKVOptionalV, builder: func(args any) CheckFunc { k, v := args.(*StrTuple).Unpack() + if v == "" { + return func(cached Cache, r *http.Request) bool { + return r.FormValue(k) != "" + } + } return func(cached Cache, r *http.Request) bool { return r.FormValue(k) == v } @@ -111,13 +142,18 @@ var checkers = map[string]struct { help: Help{ command: OnPostForm, args: map[string]string{ - "key": "the form key", - "value": "the form value", + "key": "the form key", + "[value]": "the form value", }, }, - validate: toStrTuple, + validate: toKVOptionalV, builder: func(args any) CheckFunc { k, v := args.(*StrTuple).Unpack() + if v == "" { + return func(cached Cache, r *http.Request) bool { + return r.PostFormValue(k) != "" + } + } return func(cached Cache, r *http.Request) bool { return r.PostFormValue(k) == v } diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index acbe384..c5bdc8e 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -15,25 +15,50 @@ func TestParseOn(t *testing.T) { }{ // header { - name: "header_valid", + name: "header_valid_kv", input: "header Connection Upgrade", wantErr: nil, }, { - name: "header_invalid", + name: "header_valid_k", input: "header Connection", - wantErr: ErrInvalidArguments, + wantErr: nil, + }, + { + name: "header_missing_arg", + input: "header", + wantErr: ErrExpectKVOptionalV, }, // query { - name: "query_valid", + name: "query_valid_kv", input: "query key value", wantErr: nil, }, { - name: "query_invalid", + name: "query_valid_k", input: "query key", - wantErr: ErrInvalidArguments, + wantErr: nil, + }, + { + name: "query_missing_arg", + input: "query", + wantErr: ErrExpectKVOptionalV, + }, + { + name: "cookie_valid_kv", + input: "cookie key value", + wantErr: nil, + }, + { + name: "cookie_valid_k", + input: "cookie key", + wantErr: nil, + }, + { + name: "cookie_missing_arg", + input: "cookie", + wantErr: ErrExpectKVOptionalV, }, // method { @@ -43,9 +68,14 @@ func TestParseOn(t *testing.T) { }, { name: "method_invalid", - input: "method", + input: "method invalid", wantErr: ErrInvalidArguments, }, + { + name: "method_missing_arg", + input: "method", + wantErr: ErrExpectOneArg, + }, // path { name: "path_valid", @@ -53,9 +83,9 @@ func TestParseOn(t *testing.T) { wantErr: nil, }, { - name: "path_invalid", + name: "path_missing_arg", input: "path", - wantErr: ErrInvalidArguments, + wantErr: ErrExpectOneArg, }, // remote { @@ -65,9 +95,14 @@ func TestParseOn(t *testing.T) { }, { name: "remote_invalid", - input: "remote", + input: "remote abcd", wantErr: ErrInvalidArguments, }, + { + name: "remote_missing_arg", + input: "remote", + wantErr: ErrExpectOneArg, + }, { name: "unknown_target", input: "unknown", diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index 818ce96..e894309 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -36,6 +36,18 @@ func toStrTuple(args []string) (any, E.Error) { return &StrTuple{args[0], args[1]}, nil } +// toKVOptionalV returns *StrTuple that value is optional. +func toKVOptionalV(args []string) (any, E.Error) { + switch len(args) { + case 1: + return &StrTuple{args[0], ""}, nil + case 2: + return &StrTuple{args[0], args[1]}, nil + default: + return nil, ErrExpectKVOptionalV + } +} + // validateURL returns types.URL with the URL validated. func validateURL(args []string) (any, E.Error) { if len(args) != 1 {