package rules_test

import (
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"testing"

	"github.com/yusing/go-proxy/internal/route"
	"github.com/yusing/go-proxy/internal/route/routes"
	. "github.com/yusing/go-proxy/internal/route/rules"
	expect "github.com/yusing/go-proxy/internal/utils/testing"
	"golang.org/x/crypto/bcrypt"
)

// testCorrectness describes a single RuleOn check: the rule text to parse
// (checker), the request to evaluate it against (input), and the expected
// result of RuleOn.Check (want).
type testCorrectness struct {
	name    string
	checker string
	input   *http.Request
	want    bool
}

// genCorrectnessTestCases builds four cases for a key/value-style field
// (e.g. "header", "query", "cookie"):
//
//   - <field>_match:      "<field> foo bar" against foo=bar  -> true
//   - <field>_no_match:   "<field> foo baz" against foo=bar  -> false
//   - <field>_exists:     "<field> foo"     against foo=abcd -> true
//   - <field>_not_exists: "<field> foo"     against bar=abcd -> false
//
// genRequest must return a request that carries key k with value v in the
// place the field inspects.
func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Request) []testCorrectness {
	return []testCorrectness{
		{
			name:    field + "_match",
			checker: field + " foo bar",
			input:   genRequest("foo", "bar"),
			want:    true,
		},
		{
			name:    field + "_no_match",
			checker: field + " foo baz",
			input:   genRequest("foo", "bar"),
			want:    false,
		},
		{
			name:    field + "_exists",
			checker: field + " foo",
			input:   genRequest("foo", "abcd"),
			want:    true,
		},
		{
			name:    field + "_not_exists",
			checker: field + " foo",
			input:   genRequest("bar", "abcd"),
			want:    false,
		},
	}
}

// TestOnCorrectness parses each checker string into a RuleOn and verifies
// that Check returns the expected result for the given request.
func TestOnCorrectness(t *testing.T) {
	// bcrypt at DefaultCost is deliberately slow; hash once and share the
	// result between the basic_auth cases instead of hashing per case.
	passwordHash := string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)))

	tests := []testCorrectness{
		{
			name:    "method_match",
			checker: "method GET",
			input:   &http.Request{Method: http.MethodGet},
			want:    true,
		},
		{
			name:    "method_no_match",
			checker: "method GET",
			input:   &http.Request{Method: http.MethodPost},
			want:    false,
		},
		{
			name:    "path_exact_match",
			checker: "path /example",
			input: &http.Request{
				URL: &url.URL{Path: "/example"},
			},
			want: true,
		},
		{
			name:    "path_wildcard_match",
			checker: "path /example/*",
			input: &http.Request{
				URL: &url.URL{Path: "/example/123"},
			},
			want: true,
		},
		{
			name:    "remote_match",
			checker: "remote 192.168.1.0/24",
			input: &http.Request{
				RemoteAddr: "192.168.1.5",
			},
			want: true,
		},
		{
			name:    "remote_no_match",
			checker: "remote 192.168.1.0/24",
			input: &http.Request{
				RemoteAddr: "192.168.2.5",
			},
			want: false,
		},
		{
			name:    "basic_auth_correct",
			checker: "basic_auth user " + passwordHash,
			input: &http.Request{
				Header: http.Header{
					// credentials "user:password" (matches the hash)
					"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))},
				},
			},
			want: true,
		},
		{
			name:    "basic_auth_incorrect",
			checker: "basic_auth user " + passwordHash,
			input: &http.Request{
				Header: http.Header{
					// credentials "user:incorrect" (wrong password)
					"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))},
				},
			},
			want: false,
		},
		{
			name:    "route_match",
			checker: "route example",
			input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
				Alias: "example",
				Root:  "/",
			}))),
			want: true,
		},
		{
			name:    "route_no_match",
			checker: "route example",
			input: &http.Request{
				Header: http.Header{},
			},
			want: false,
		},
	}

	tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request {
		return &http.Request{
			Header: http.Header{k: []string{v}},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("query", func(k, v string) *http.Request {
		return &http.Request{
			URL: &url.URL{
				RawQuery: fmt.Sprintf("%s=%s", k, v),
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("cookie", func(k, v string) *http.Request {
		return &http.Request{
			Header: http.Header{
				"Cookie": {fmt.Sprintf("%s=%s", k, v)},
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("form", func(k, v string) *http.Request {
		return &http.Request{
			Form: url.Values{
				k: []string{v},
			},
		}
	})...)
	tests = append(tests, genCorrectnessTestCases("postform", func(k, v string) *http.Request {
		return &http.Request{
			PostForm: url.Values{
				k: []string{v},
			},
		}
	})...)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var on RuleOn
			err := on.Parse(tt.checker)
			expect.NoError(t, err)
			got := on.Check(Cache{}, tt.input)
			expect.Equal(t, tt.want, got)
		})
	}
}