diff --git a/internal/api/handler.go b/internal/api/handler.go index ad47f2d..f3a9162 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -9,6 +9,8 @@ import ( "github.com/yusing/go-proxy/internal/api/v1/auth" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/net/http/middleware" ) type ServeMux struct{ *http.ServeMux } @@ -18,7 +20,7 @@ func NewServeMux() ServeMux { } func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { - mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler)) + mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(rateLimited(handler))) } func NewHandler() http.Handler { @@ -56,3 +58,16 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { f(w, r) } } + +func rateLimited(f http.HandlerFunc) http.HandlerFunc { + m, err := middleware.RateLimiter.WithOptionsClone(middleware.OptionsRaw{ + "average": 10, + "burst": 10, + }) + if err != nil { + logging.Fatal().Err(err).Msg("unable to create API rate limiter") + } + return func(w http.ResponseWriter, r *http.Request) { + m.ModifyRequest(f, w, r) + } +} diff --git a/internal/config/query.go b/internal/config/query.go index 99131cb..879204c 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -89,7 +89,7 @@ func HomepageConfig() homepage.Config { if item.URL == "" { if len(domains) > 0 { - item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port) + item.URL = fmt.Sprintf("%s://%s%s:%s", proto, strings.ToLower(alias), domains[0], port) } } item.AltURL = r.TargetURL().String()