From 5fa0d47c0d49ac1d1a7f0d3b689ea6e9853bcc9f Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 1 Jan 2025 17:05:43 +0800 Subject: [PATCH] more flexible domain matching --- internal/entrypoint/entrypoint.go | 50 +++-------- internal/entrypoint/entrypoint_test.go | 120 +++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 36 deletions(-) create mode 100644 internal/entrypoint/entrypoint_test.go diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 82012ad..0046a93 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -27,6 +27,8 @@ var ( epAccessLoggerMu sync.Mutex ) +var ErrNoSuchRoute = errors.New("no such route") + func SetFindRouteDomains(domains []string) { if len(domains) == 0 { findRouteFunc = findRouteAnyDomain @@ -73,14 +75,6 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { func Handler(w http.ResponseWriter, r *http.Request) { mux, err := findRouteFunc(r.Host) - if err != nil { - // try find with exact match - r, ok := routes.GetHTTPRoute(r.Host) - if ok { - mux = r - err = nil - } - } if err == nil { if epAccessLogger != nil { epMiddlewareMu.Lock() @@ -126,45 +120,29 @@ func Handler(w http.ResponseWriter, r *http.Request) { func findRouteAnyDomain(host string) (route.HTTPRoute, error) { hostSplit := strutils.SplitRune(host, '.') - n := len(hostSplit) - switch { - case n == 3: - host = hostSplit[0] - case n > 3: - var builder strings.Builder - builder.Grow(2*n - 3) - builder.WriteString(hostSplit[0]) - for _, part := range hostSplit[:n-2] { - builder.WriteRune('.') - builder.WriteString(part) - } - host = builder.String() - default: - return nil, errors.New("missing subdomain in url") - } - if r, ok := routes.GetHTTPRoute(host); ok { + target := hostSplit[0] + + if r, ok := routes.GetHTTPRouteOrExact(target, host); ok { return r, nil } - return nil, fmt.Errorf("no such route: %s", host) + return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, target) } func findRouteByDomains(domains []string) func(host string) (route.HTTPRoute, error) { return func(host string) (route.HTTPRoute, error) { - var subdomain string - for _, domain := range domains { if strings.HasSuffix(host, domain) { - subdomain = strings.TrimSuffix(host, domain) - break + target := strings.TrimSuffix(host, domain) + if r, ok := routes.GetHTTPRoute(target); ok { + return r, nil + } } } - if subdomain != "" { // matched - if r, ok := routes.GetHTTPRoute(subdomain); ok { - return r, nil - } - return nil, fmt.Errorf("no such route: %s", subdomain) + // fallback to exact match + if r, ok := routes.GetHTTPRoute(host); ok { + return r, nil } - return nil, fmt.Errorf("%s does not match any base domain", host) + return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host) } } diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go new file mode 100644 index 0000000..65153a4 --- /dev/null +++ b/internal/entrypoint/entrypoint_test.go @@ -0,0 +1,120 @@ +package entrypoint + +import ( + "testing" + + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/routes" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +var r route.HTTPRoute + +func run(t *testing.T, match []string, noMatch []string) { + t.Helper() + t.Cleanup(routes.TestClear) + t.Cleanup(func() { + SetFindRouteDomains(nil) + }) + + for _, test := range match { + t.Run(test, func(t *testing.T) { + found, err := findRouteFunc(test) + ExpectNoError(t, err) + ExpectTrue(t, found == &r) + }) + } + + for _, test := range noMatch { + t.Run(test, func(t *testing.T) { + _, err := findRouteFunc(test) + ExpectError(t, ErrNoSuchRoute, err) + }) + } +} + +func TestFindRouteAnyDomain(t *testing.T) { + routes.SetHTTPRoute("app1", &r) + + tests := []string{ + "app1.com", + "app1.domain.com", + "app1.sub.domain.com", + } + testsNoMatch := []string{ + "sub.app1.com", + "app2.com", + "app2.domain.com", + "app2.sub.domain.com", + } + + run(t, tests, testsNoMatch) +} + +func TestFindRouteExactHostMatch(t *testing.T) { + tests := []string{ + "app2.com", + "app2.domain.com", + "app2.sub.domain.com", + } + testsNoMatch := []string{ + "sub.app2.com", + "app1.com", + "app1.domain.com", + "app1.sub.domain.com", + } + + for _, test := range tests { + routes.SetHTTPRoute(test, &r) + } + + run(t, tests, testsNoMatch) +} + +func TestFindRouteByDomains(t *testing.T) { + SetFindRouteDomains([]string{ + ".domain.com", + ".sub.domain.com", + }) + + routes.SetHTTPRoute("app1", &r) + + tests := []string{ + "app1.domain.com", + "app1.sub.domain.com", + } + testsNoMatch := []string{ + "sub.app1.com", + "app1.com", + "app1.domain.co", + "app1.domain.com.hk", + "app1.sub.domain.co", + "app2.domain.com", + "app2.sub.domain.com", + } + + run(t, tests, testsNoMatch) +} + +func TestFindRouteByDomainsExactMatch(t *testing.T) { + SetFindRouteDomains([]string{ + ".domain.com", + ".sub.domain.com", + }) + + routes.SetHTTPRoute("app1.foo.bar", &r) + + tests := []string{ + "app1.foo.bar", // exact match + "app1.foo.bar.domain.com", + "app1.foo.bar.sub.domain.com", + } + testsNoMatch := []string{ + "sub.app1.foo.bar", + "sub.app1.foo.bar.com", + "app1.domain.com", + "app1.sub.domain.com", + } + + run(t, tests, testsNoMatch) +}