more flexible domain matching

This commit is contained in:
yusing 2025-01-01 17:05:43 +08:00
parent 659ad29875
commit 5fa0d47c0d
2 changed files with 134 additions and 36 deletions

View file

@ -27,6 +27,8 @@ var (
epAccessLoggerMu sync.Mutex epAccessLoggerMu sync.Mutex
) )
var ErrNoSuchRoute = errors.New("no such route")
func SetFindRouteDomains(domains []string) { func SetFindRouteDomains(domains []string) {
if len(domains) == 0 { if len(domains) == 0 {
findRouteFunc = findRouteAnyDomain findRouteFunc = findRouteAnyDomain
@ -73,14 +75,6 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
func Handler(w http.ResponseWriter, r *http.Request) { func Handler(w http.ResponseWriter, r *http.Request) {
mux, err := findRouteFunc(r.Host) mux, err := findRouteFunc(r.Host)
if err != nil {
// try find with exact match
r, ok := routes.GetHTTPRoute(r.Host)
if ok {
mux = r
err = nil
}
}
if err == nil { if err == nil {
if epAccessLogger != nil { if epAccessLogger != nil {
epMiddlewareMu.Lock() epMiddlewareMu.Lock()
@ -126,45 +120,29 @@ func Handler(w http.ResponseWriter, r *http.Request) {
func findRouteAnyDomain(host string) (route.HTTPRoute, error) { func findRouteAnyDomain(host string) (route.HTTPRoute, error) {
hostSplit := strutils.SplitRune(host, '.') hostSplit := strutils.SplitRune(host, '.')
n := len(hostSplit) target := hostSplit[0]
switch {
case n == 3: if r, ok := routes.GetHTTPRouteOrExact(target, host); ok {
host = hostSplit[0]
case n > 3:
var builder strings.Builder
builder.Grow(2*n - 3)
builder.WriteString(hostSplit[0])
for _, part := range hostSplit[:n-2] {
builder.WriteRune('.')
builder.WriteString(part)
}
host = builder.String()
default:
return nil, errors.New("missing subdomain in url")
}
if r, ok := routes.GetHTTPRoute(host); ok {
return r, nil return r, nil
} }
return nil, fmt.Errorf("no such route: %s", host) return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, target)
} }
func findRouteByDomains(domains []string) func(host string) (route.HTTPRoute, error) { func findRouteByDomains(domains []string) func(host string) (route.HTTPRoute, error) {
return func(host string) (route.HTTPRoute, error) { return func(host string) (route.HTTPRoute, error) {
var subdomain string
for _, domain := range domains { for _, domain := range domains {
if strings.HasSuffix(host, domain) { if strings.HasSuffix(host, domain) {
subdomain = strings.TrimSuffix(host, domain) target := strings.TrimSuffix(host, domain)
break if r, ok := routes.GetHTTPRoute(target); ok {
return r, nil
}
} }
} }
if subdomain != "" { // matched // fallback to exact match
if r, ok := routes.GetHTTPRoute(subdomain); ok { if r, ok := routes.GetHTTPRoute(host); ok {
return r, nil return r, nil
} }
return nil, fmt.Errorf("no such route: %s", subdomain) return nil, fmt.Errorf("%w: %s", ErrNoSuchRoute, host)
}
return nil, fmt.Errorf("%s does not match any base domain", host)
} }
} }

View file

@ -0,0 +1,120 @@
package entrypoint
import (
"testing"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/routes"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var r route.HTTPRoute
func run(t *testing.T, match []string, noMatch []string) {
t.Helper()
t.Cleanup(routes.TestClear)
t.Cleanup(func() {
SetFindRouteDomains(nil)
})
for _, test := range match {
t.Run(test, func(t *testing.T) {
found, err := findRouteFunc(test)
ExpectNoError(t, err)
ExpectTrue(t, found == &r)
})
}
for _, test := range noMatch {
t.Run(test, func(t *testing.T) {
_, err := findRouteFunc(test)
ExpectError(t, ErrNoSuchRoute, err)
})
}
}
func TestFindRouteAnyDomain(t *testing.T) {
routes.SetHTTPRoute("app1", &r)
tests := []string{
"app1.com",
"app1.domain.com",
"app1.sub.domain.com",
}
testsNoMatch := []string{
"sub.app1.com",
"app2.com",
"app2.domain.com",
"app2.sub.domain.com",
}
run(t, tests, testsNoMatch)
}
func TestFindRouteExactHostMatch(t *testing.T) {
tests := []string{
"app2.com",
"app2.domain.com",
"app2.sub.domain.com",
}
testsNoMatch := []string{
"sub.app2.com",
"app1.com",
"app1.domain.com",
"app1.sub.domain.com",
}
for _, test := range tests {
routes.SetHTTPRoute(test, &r)
}
run(t, tests, testsNoMatch)
}
func TestFindRouteByDomains(t *testing.T) {
SetFindRouteDomains([]string{
".domain.com",
".sub.domain.com",
})
routes.SetHTTPRoute("app1", &r)
tests := []string{
"app1.domain.com",
"app1.sub.domain.com",
}
testsNoMatch := []string{
"sub.app1.com",
"app1.com",
"app1.domain.co",
"app1.domain.com.hk",
"app1.sub.domain.co",
"app2.domain.com",
"app2.sub.domain.com",
}
run(t, tests, testsNoMatch)
}
func TestFindRouteByDomainsExactMatch(t *testing.T) {
SetFindRouteDomains([]string{
".domain.com",
".sub.domain.com",
})
routes.SetHTTPRoute("app1.foo.bar", &r)
tests := []string{
"app1.foo.bar", // exact match
"app1.foo.bar.domain.com",
"app1.foo.bar.sub.domain.com",
}
testsNoMatch := []string{
"sub.app1.foo.bar",
"sub.app1.foo.bar.com",
"app1.domain.com",
"app1.sub.domain.com",
}
run(t, tests, testsNoMatch)
}