diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 2719443..66cd116 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -23,10 +23,10 @@ lint: enabled: - hadolint@2.12.1-beta - actionlint@1.7.6 - - checkov@3.2.347 + - checkov@3.2.350 - git-diff-check - gofmt@1.20.4 - - golangci-lint@1.62.2 + - golangci-lint@1.63.4 - osv-scanner@1.9.2 - oxipng@9.1.3 - prettier@3.4.2 diff --git a/README.md b/README.md index ad4f742..b4751ed 100755 --- a/README.md +++ b/README.md @@ -87,8 +87,10 @@ Setup DNS Records point to machine which runs `GoDoxy`, e.g. - change username and password for WebUI authentication ```shell - sed -i "s|API_USERNAME=.*|API_USERNAME=admin|g" .env - sed -i "s|API_PASSWORD=.*|API_PASSWORD=some-strong-password|g" .env + USERNAME=admin + PASSWORD=some-password + sed -i "s|API_USERNAME=.*|API_USERNAME=${USERNAME}|g" .env + sed -i "s|API_PASSWORD=.*|API_PASSWORD=${PASSWORD}|g" .env ``` 4. _(Optional)_ setup `docker-socket-proxy` other docker nodes (see [Multi docker nodes setup](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)) then add them inside `config.yml` diff --git a/README_CHT.md b/README_CHT.md index 9673dd5..702e7e0 100644 --- a/README_CHT.md +++ b/README_CHT.md @@ -87,8 +87,10 @@ _加入我們的 [Discord](https://discord.gg/umReR62nRd) 獲取幫助和討論_ - 更改網頁介面認證的使用者名稱和密碼 ```shell - sed -i "s|API_USERNAME=.*|API_USERNAME=admin|g" .env - sed -i "s|API_PASSWORD=.*|API_PASSWORD=some-strong-password|g" .env + USERNAME=admin + PASSWORD=some-password + sed -i "s|API_USERNAME=.*|API_USERNAME=${USERNAME}|g" .env + sed -i "s|API_PASSWORD=.*|API_PASSWORD=${PASSWORD}|g" .env ``` 4. _(可選)_ 設置其他 Docker 節點的 `docker-socket-proxy`(參見 [多 Docker 節點設置](https://github.com/yusing/go-proxy/wiki/Configurations#multi-docker-nodes-setup)),然後在 `config.yml` 中添加它們 diff --git a/cmd/main.go b/cmd/main.go index 4c60fbd..e1242ec 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,24 +3,19 @@ package main import ( "encoding/json" "log" - "net/http" "os" "os/signal" "syscall" "time" "github.com/yusing/go-proxy/internal" - "github.com/yusing/go-proxy/internal/api" - "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" - "github.com/yusing/go-proxy/internal/entrypoint" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/metrics" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/net/http/server" + "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/pkg" ) @@ -98,16 +93,16 @@ func main() { switch args.Command { case common.CommandListRoutes: cfg.StartProxyProviders() - printJSON(config.RoutesByAlias()) + printJSON(routes.RoutesByAlias()) return case common.CommandListConfigs: - printJSON(config.Value()) + printJSON(cfg.Value()) return case common.CommandDebugListEntries: - printJSON(config.DumpEntries()) + printJSON(cfg.DumpEntries()) return case common.CommandDebugListProviders: - printJSON(config.DumpProviders()) + printJSON(cfg.DumpProviders()) return } @@ -115,58 +110,25 @@ func main() { logging.Warn().Msg("API JWT secret is empty, authentication is disabled") } - cfg.StartProxyProviders() + cfg.Start() config.WatchChanges() - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT) - signal.Notify(sig, syscall.SIGTERM) - signal.Notify(sig, syscall.SIGHUP) - - autocert := config.GetAutoCertProvider() - if autocert != nil { - if err := autocert.Setup(); err != nil { - E.LogFatal("autocert setup error", err) - } - } else { - logging.Info().Msg("autocert not configured") - } - - server.StartServer(server.Options{ - Name: "proxy", - CertProvider: autocert, - HTTPAddr: common.ProxyHTTPAddr, - HTTPSAddr: common.ProxyHTTPSAddr, - Handler: http.HandlerFunc(entrypoint.Handler), - }) - // Initialize authentication providers if err := auth.Initialize(); err != nil { logging.Warn().Err(err).Msg("Failed to initialize authentication providers") } - server.StartServer(server.Options{ - Name: "api", - CertProvider: autocert, - HTTPAddr: common.APIHTTPAddr, - Handler: api.NewHandler(), - }) - - if common.PrometheusEnabled { - server.StartServer(server.Options{ - Name: "metrics", - CertProvider: autocert, - HTTPAddr: common.MetricsHTTPAddr, - Handler: metrics.NewHandler(), - }) - } + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT) + signal.Notify(sig, syscall.SIGTERM) + signal.Notify(sig, syscall.SIGHUP) // wait for signal <-sig - // gracefully shutdown + // grafully shutdown logging.Info().Msg("shutting down") - _ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown)) + _ = task.GracefulShutdown(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) } func prepareDirectory(dir string) { diff --git a/go.mod b/go.mod index a20c6c8..b5af539 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( go.opentelemetry.io/otel/trace v1.33.0 // indirect golang.org/x/crypto v0.32.0 // indirect golang.org/x/mod v0.22.0 // indirect + golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/tools v0.29.0 // indirect diff --git a/internal/api/handler.go b/internal/api/handler.go index 76a3dc7..6d1f80a 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -8,20 +8,17 @@ import ( "github.com/yusing/go-proxy/internal/api/v1/auth" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" + config "github.com/yusing/go-proxy/internal/config/types" ) type ServeMux struct{ *http.ServeMux } -func NewServeMux() ServeMux { - return ServeMux{http.NewServeMux()} -} - func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) } -func NewHandler() http.Handler { - mux := NewServeMux() +func NewHandler(cfg config.ConfigInstance) http.Handler { + mux := ServeMux{http.NewServeMux()} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) @@ -30,19 +27,25 @@ func NewHandler() http.Handler { mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) - mux.HandleFunc("POST", "/v1/reload", v1.Reload) - mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List)) - mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List)) - mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List)) + mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) + mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) + mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) + mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) - mux.HandleFunc("GET", "/v1/stats", v1.Stats) - mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) + mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) + mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) return mux } +func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(cfg, w, r) + } +} + // allow only requests to API server with localhost. func checkHost(f http.HandlerFunc) http.HandlerFunc { if common.IsDebug { diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index bc8720b..63f3aa3 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -9,7 +9,7 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/route/provider" diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 85d41a5..52617b4 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -6,9 +6,10 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/routes" + route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) @@ -24,7 +25,7 @@ const ( ListTasks = "tasks" ) -func List(w http.ResponseWriter, r *http.Request) { +func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { what := r.PathValue("what") if what == "" { what = ListRoutes @@ -40,7 +41,7 @@ func List(w http.ResponseWriter, r *http.Request) { U.RespondJSON(w, r, route) } case ListRoutes: - U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type")))) + U.RespondJSON(w, r, routes.RoutesByAlias(route.RouteType(r.FormValue("type")))) case ListFiles: listFiles(w, r) case ListMiddlewares: @@ -48,9 +49,9 @@ func List(w http.ResponseWriter, r *http.Request) { case ListMiddlewareTraces: U.RespondJSON(w, r, middleware.GetAllTrace()) case ListMatchDomains: - U.RespondJSON(w, r, config.Value().MatchDomains) + U.RespondJSON(w, r, cfg.Value().MatchDomains) case ListHomepageConfig: - U.RespondJSON(w, r, config.HomepageConfig()) + U.RespondJSON(w, r, routes.HomepageConfig(cfg.Value().Homepage.UseDefaultCategories)) case ListTasks: U.RespondJSON(w, r, task.DebugTaskList()) default: @@ -60,9 +61,9 @@ func List(w http.ResponseWriter, r *http.Request) { func listRoute(which string) any { if which == "" || which == "all" { - return config.RoutesByAlias() + return routes.RoutesByAlias() } - routes := config.RoutesByAlias() + routes := routes.RoutesByAlias() route, ok := routes[which] if !ok { return nil diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index 42a4198..defa4e4 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -4,11 +4,11 @@ import ( "net/http" U "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" ) -func Reload(w http.ResponseWriter, r *http.Request) { - if err := config.Reload(); err != nil { +func Reload(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { + if err := cfg.Reload(); err != nil { U.HandleErr(w, r, err) return } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 99ba304..0d9617b 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -9,25 +9,25 @@ import ( "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) -func Stats(w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, getStats()) +func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { + U.RespondJSON(w, r, getStats(cfg)) } -func StatsWS(w http.ResponseWriter, r *http.Request) { +func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { var originPats []string localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - if len(config.Value().MatchDomains) == 0 { + if len(cfg.Value().MatchDomains) == 0 { U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") originPats = []string{"*"} } else { - originPats = make([]string, len(config.Value().MatchDomains)) - for i, domain := range config.Value().MatchDomains { + originPats = make([]string, len(cfg.Value().MatchDomains)) + for i, domain := range cfg.Value().MatchDomains { originPats[i] = "*" + domain } originPats = append(originPats, localAddresses...) @@ -52,7 +52,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { defer ticker.Stop() for range ticker.C { - stats := getStats() + stats := getStats(cfg) if err := wsjson.Write(ctx, conn, stats); err != nil { U.LogError(r).Msg("failed to write JSON") return @@ -62,9 +62,9 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { var startTime = time.Now() -func getStats() map[string]any { +func getStats(cfg config.ConfigInstance) map[string]any { return map[string]any{ - "proxies": config.Statistics(), + "proxies": cfg.Statistics(), "uptime": strutils.FormatDuration(time.Since(startTime)), } } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 2c45aa2..deb010e 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -7,6 +7,7 @@ import ( "os" "path" "reflect" + "runtime" "sort" "time" @@ -148,28 +149,40 @@ func (p *Provider) ShouldRenewOn() time.Time { panic("no certificate available") } -func (p *Provider) ScheduleRenewal() { +func (p *Provider) ScheduleRenewal(parent task.Parent) { if p.GetName() == ProviderLocal { return } go func() { - task := task.RootTask("cert-renew-scheduler", true) + lastErrOn := time.Time{} + renewalTime := p.ShouldRenewOn() + timer := time.NewTimer(time.Until(renewalTime)) + defer timer.Stop() + + task := parent.Subtask("cert-renew-scheduler") defer task.Finish(nil) for { - renewalTime := p.ShouldRenewOn() - timer := time.NewTimer(time.Until(renewalTime)) - select { case <-task.Context().Done(): - timer.Stop() return case <-timer.C: + // Retry after 1 hour on failure + if time.Now().Before(lastErrOn.Add(time.Hour)) { + continue + } if err := p.renewIfNeeded(); err != nil { E.LogWarn("cert renew failed", err, &logger) - // Retry after 1 hour on failure - time.Sleep(time.Hour) + lastErrOn = time.Now() + continue } + // Reset on success + lastErrOn = time.Time{} + renewalTime = p.ShouldRenewOn() + timer.Reset(time.Until(renewalTime)) + default: + // Allow other tasks to run + runtime.Gosched() } } }() diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 399cb75..5f172da 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -18,8 +18,6 @@ func (p *Provider) Setup() (err E.Error) { } } - p.ScheduleRenewal() - for _, expiry := range p.GetExpiries() { logger.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) break diff --git a/internal/common/env.go b/internal/common/env.go index 9e23ec5..ae7e102 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -11,6 +11,105 @@ import ( "github.com/rs/zerolog/log" ) +var ( + prefixes = []string{"GODOXY_", "GOPROXY_", ""} + + IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test") + IsDebug = GetEnvBool("DEBUG", IsTest) + IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false) + IsTrace = GetEnvBool("TRACE", false) && IsDebug + IsProduction = !IsTest && !IsDebug + + ProxyHTTPAddr, + ProxyHTTPHost, + ProxyHTTPPort, + ProxyHTTPURL = GetAddrEnv("HTTP_ADDR", ":80", "http") + + ProxyHTTPSAddr, + ProxyHTTPSHost, + ProxyHTTPSPort, + ProxyHTTPSURL = GetAddrEnv("HTTPS_ADDR", ":443", "https") + + APIHTTPAddr, + APIHTTPHost, + APIHTTPPort, + APIHTTPURL = GetAddrEnv("API_ADDR", "127.0.0.1:8888", "http") + + MetricsHTTPAddr, + MetricsHTTPHost, + MetricsHTTPPort, + MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http") + PrometheusEnabled = MetricsHTTPURL != "" + + APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) + APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) + APIUser = GetEnvString("API_USER", "admin") + APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) +) + +func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T { + var value string + var ok bool + for _, prefix := range prefixes { + value, ok = os.LookupEnv(prefix + key) + if ok && value != "" { + break + } + } + if !ok || value == "" { + return defaultValue + } + parsed, err := parser(value) + if err == nil { + return parsed + } + log.Fatal().Err(err).Msgf("env %s: invalid %T value: %s", key, parsed, value) + return defaultValue +} + +func GetEnvString(key string, defaultValue string) string { + return GetEnv(key, defaultValue, func(s string) (string, error) { + return s, nil + }) +} + +func GetEnvBool(key string, defaultValue bool) bool { + return GetEnv(key, defaultValue, strconv.ParseBool) +} + +func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL string) { + addr = GetEnvString(key, defaultValue) + if addr == "" { + return + } + host, port, err := net.SplitHostPort(addr) + if err != nil { + log.Fatal().Msgf("env %s: invalid address: %s", key, addr) + } + if host == "" { + host = "localhost" + } + fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port) + return +} + +func GetDurationEnv(key string, defaultValue time.Duration) time.Duration { + return GetEnv(key, defaultValue, time.ParseDuration) +} + +package common + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog/log" +) + var ( prefixes = []string{"GODOXY_", "GOPROXY_", ""} diff --git a/internal/config/config.go b/internal/config/config.go index f1c42b7..da9b89e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,12 +7,15 @@ import ( "sync" "time" + "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/entrypoint" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/metrics" + "github.com/yusing/go-proxy/internal/net/http/server" "github.com/yusing/go-proxy/internal/notif" proxy "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/task" @@ -26,7 +29,9 @@ type Config struct { value *types.Config providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider - task *task.Task + entrypoint *entrypoint.Entrypoint + + task *task.Task } var ( @@ -45,15 +50,18 @@ Make sure you rename it back before next time you start.` You may run "ls-config" to show or dump the current config.` ) +var Validate = types.Validate + func GetInstance() *Config { return instance } func newConfig() *Config { return &Config{ - value: types.DefaultConfig(), - providers: F.NewMapOf[string, *proxy.Provider](), - task: task.RootTask("config", false), + value: types.DefaultConfig(), + providers: F.NewMapOf[string, *proxy.Provider](), + entrypoint: entrypoint.NewEntrypoint(), + task: task.RootTask("config", false), } } @@ -66,11 +74,6 @@ func Load() (*Config, E.Error) { return instance, instance.load() } -func Validate(data []byte) E.Error { - var model types.Config - return utils.DeserializeYAML(data, &model) -} - func MatchDomains() []string { return instance.value.MatchDomains } @@ -101,6 +104,7 @@ func OnConfigChange(ev []events.Event) { } if err := Reload(); err != nil { + logger.Warn().Msg("using last config") // recovered in event queue panic(err) } @@ -122,15 +126,19 @@ func Reload() E.Error { // -> replace config -> start new subtasks instance.task.Finish("config changed") instance = newCfg - instance.StartProxyProviders() + instance.Start() return nil } -func Value() types.Config { - return *instance.value +func (cfg *Config) Value() *types.Config { + return instance.value } -func GetAutoCertProvider() *autocert.Provider { +func (cfg *Config) Reload() E.Error { + return Reload() +} + +func (cfg *Config) AutoCertProvider() *autocert.Provider { return instance.autocertProvider } @@ -138,6 +146,26 @@ func (cfg *Config) Task() *task.Task { return cfg.task } +func (cfg *Config) Start() { + cfg.StartAutoCert() + cfg.StartProxyProviders() + cfg.StartServers() +} + +func (cfg *Config) StartAutoCert() { + autocert := cfg.autocertProvider + if autocert == nil { + logging.Info().Msg("autocert not configured") + return + } + + if err := autocert.Setup(); err != nil { + E.LogFatal("autocert setup error", err) + } else { + autocert.ScheduleRenewal(cfg.task) + } +} + func (cfg *Config) StartProxyProviders() { errs := cfg.providers.CollectErrorsParallel( func(_ string, p *proxy.Provider) error { @@ -149,6 +177,30 @@ func (cfg *Config) StartProxyProviders() { } } +func (cfg *Config) StartServers() { + server.StartServer(cfg.task, server.Options{ + Name: "proxy", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.ProxyHTTPAddr, + HTTPSAddr: common.ProxyHTTPSAddr, + Handler: cfg.entrypoint, + }) + server.StartServer(cfg.task, server.Options{ + Name: "api", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.APIHTTPAddr, + Handler: api.NewHandler(cfg), + }) + if common.PrometheusEnabled { + server.StartServer(cfg.task, server.Options{ + Name: "metrics", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.MetricsHTTPAddr, + Handler: metrics.NewHandler(), + }) + } +} + func (cfg *Config) load() E.Error { const errMsg = "config load error" @@ -164,8 +216,8 @@ func (cfg *Config) load() E.Error { // errors are non fatal below errs := E.NewBuilder(errMsg) - errs.Add(entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) - errs.Add(entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) + errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) + errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) @@ -176,7 +228,8 @@ func (cfg *Config) load() E.Error { model.MatchDomains[i] = "." + domain } } - entrypoint.SetFindRouteDomains(model.MatchDomains) + cfg.entrypoint.SetFindRouteDomains(model.MatchDomains) + return errs.Error() } diff --git a/internal/config/query.go b/internal/config/query.go index 4634b7a..a3533f3 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -1,20 +1,14 @@ package config import ( - "strings" - - "github.com/yusing/go-proxy/internal/homepage" route "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/route/entry" - proxy "github.com/yusing/go-proxy/internal/route/provider" - "github.com/yusing/go-proxy/internal/route/routes" + "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/route/types" - "github.com/yusing/go-proxy/internal/utils/strutils" ) -func DumpEntries() map[string]*types.RawEntry { +func (cfg *Config) DumpEntries() map[string]*types.RawEntry { entries := make(map[string]*types.RawEntry) - instance.providers.RangeAll(func(_ string, p *proxy.Provider) { + cfg.providers.RangeAll(func(_ string, p *provider.Provider) { p.RangeRoutes(func(alias string, r *route.Route) { entries[alias] = r.Entry }) @@ -22,107 +16,20 @@ func DumpEntries() map[string]*types.RawEntry { return entries } -func DumpProviders() map[string]*proxy.Provider { - entries := make(map[string]*proxy.Provider) - instance.providers.RangeAll(func(name string, p *proxy.Provider) { +func (cfg *Config) DumpProviders() map[string]*provider.Provider { + entries := make(map[string]*provider.Provider) + cfg.providers.RangeAll(func(name string, p *provider.Provider) { entries[name] = p }) return entries } -func HomepageConfig() homepage.Config { - hpCfg := homepage.NewHomePageConfig() - routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { - en := r.RawEntry() - item := en.Homepage - if item == nil { - item = new(homepage.Item) - item.Show = true - } - - if !item.IsEmpty() { - item.Show = true - } - - if !item.Show { - return - } - - item.Alias = alias - - if item.Name == "" { - item.Name = strutils.Title( - strings.ReplaceAll( - strings.ReplaceAll(alias, "-", " "), - "_", " ", - ), - ) - } - - if instance.value.Homepage.UseDefaultCategories { - if en.Container != nil && item.Category == "" { - if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok { - item.Category = category - } - } - - if item.Category == "" { - if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok { - item.Category = category - } - } - } - - switch { - case entry.IsDocker(r): - if item.Category == "" { - item.Category = "Docker" - } - item.SourceType = string(proxy.ProviderTypeDocker) - case entry.UseLoadBalance(r): - if item.Category == "" { - item.Category = "Load-balanced" - } - item.SourceType = "loadbalancer" - default: - if item.Category == "" { - item.Category = "Others" - } - item.SourceType = string(proxy.ProviderTypeFile) - } - - item.AltURL = r.TargetURL().String() - hpCfg.Add(item) - }) - return hpCfg -} - -func RoutesByAlias(typeFilter ...route.RouteType) map[string]any { - rts := make(map[string]any) - if len(typeFilter) == 0 || typeFilter[0] == "" { - typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream} - } - for _, t := range typeFilter { - switch t { - case route.RouteTypeReverseProxy: - routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { - rts[alias] = r - }) - case route.RouteTypeStream: - routes.GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) { - rts[alias] = r - }) - } - } - return rts -} - -func Statistics() map[string]any { +func (cfg *Config) Statistics() map[string]any { nTotalStreams := 0 nTotalRPs := 0 - providerStats := make(map[string]proxy.ProviderStats) + providerStats := make(map[string]provider.ProviderStats) - instance.providers.RangeAll(func(name string, p *proxy.Provider) { + cfg.providers.RangeAll(func(name string, p *provider.Provider) { stats := p.Statistics() providerStats[name] = stats diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 2e27891..9590337 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -3,6 +3,8 @@ package types import ( "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/utils" + + E "github.com/yusing/go-proxy/internal/error" ) type ( @@ -24,6 +26,12 @@ type ( AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` } NotificationConfig map[string]any + + ConfigInstance interface { + Value() *Config + Reload() E.Error + Statistics() map[string]any + } ) func DefaultConfig() *Config { @@ -35,6 +43,11 @@ func DefaultConfig() *Config { } } +func Validate(data []byte) E.Error { + var model Config + return utils.DeserializeYAML(data, &model) +} + func init() { utils.RegisterDefaultValueFactory(DefaultConfig) } diff --git a/internal/docker/container.go b/internal/docker/container.go index 129d289..b805d98 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -28,16 +28,17 @@ type ( PrivateIP string `json:"private_ip"` NetworkMode string `json:"network_mode"` - Aliases []string `json:"aliases"` - IsExcluded bool `json:"is_excluded"` - IsExplicit bool `json:"is_explicit"` - IsDatabase bool `json:"is_database"` - IdleTimeout string `json:"idle_timeout,omitempty"` - WakeTimeout string `json:"wake_timeout,omitempty"` - StopMethod string `json:"stop_method,omitempty"` - StopTimeout string `json:"stop_timeout,omitempty"` // stop_method = "stop" only - StopSignal string `json:"stop_signal,omitempty"` // stop_method = "stop" | "kill" only - Running bool `json:"running"` + Aliases []string `json:"aliases"` + IsExcluded bool `json:"is_excluded"` + IsExplicit bool `json:"is_explicit"` + IsDatabase bool `json:"is_database"` + IdleTimeout string `json:"idle_timeout,omitempty"` + WakeTimeout string `json:"wake_timeout,omitempty"` + StopMethod string `json:"stop_method,omitempty"` + StopTimeout string `json:"stop_timeout,omitempty"` // stop_method = "stop" only + StopSignal string `json:"stop_signal,omitempty"` // stop_method = "stop" | "kill" only + StartEndpoint string `json:"start_endpoint,omitempty"` + Running bool `json:"running"` } ) @@ -58,16 +59,17 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { PrivatePortMapping: helper.getPrivatePortMapping(), NetworkMode: c.HostConfig.NetworkMode, - Aliases: helper.getAliases(), - IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), - IsExplicit: isExplicit, - IsDatabase: helper.isDatabase(), - IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), - WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout), - StopMethod: helper.getDeleteLabel(LabelStopMethod), - StopTimeout: helper.getDeleteLabel(LabelStopTimeout), - StopSignal: helper.getDeleteLabel(LabelStopSignal), - Running: c.Status == "running" || c.State == "running", + Aliases: helper.getAliases(), + IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), + IsExplicit: isExplicit, + IsDatabase: helper.isDatabase(), + IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), + WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout), + StopMethod: helper.getDeleteLabel(LabelStopMethod), + StopTimeout: helper.getDeleteLabel(LabelStopTimeout), + StopSignal: helper.getDeleteLabel(LabelStopSignal), + StartEndpoint: helper.getDeleteLabel(LabelStartEndpoint), + Running: c.Status == "running" || c.State == "running", } res.setPrivateIP(helper) res.setPublicIP() diff --git a/internal/docker/idlewatcher/types/config.go b/internal/docker/idlewatcher/types/config.go index cb8f491..a813cec 100644 --- a/internal/docker/idlewatcher/types/config.go +++ b/internal/docker/idlewatcher/types/config.go @@ -2,6 +2,8 @@ package types import ( "errors" + "net/url" + "strings" "time" "github.com/yusing/go-proxy/internal/docker" @@ -10,11 +12,12 @@ import ( type ( Config struct { - IdleTimeout time.Duration `json:"idle_timeout,omitempty"` - WakeTimeout time.Duration `json:"wake_timeout,omitempty"` - StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument - StopMethod StopMethod `json:"stop_method,omitempty"` - StopSignal Signal `json:"stop_signal,omitempty"` + IdleTimeout time.Duration `json:"idle_timeout,omitempty"` + WakeTimeout time.Duration `json:"wake_timeout,omitempty"` + StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument + StopMethod StopMethod `json:"stop_method,omitempty"` + StopSignal Signal `json:"stop_signal,omitempty"` + StartEndpoint string `json:"start_endpoint,omitempty"` // Optional path that must be hit to start container DockerHost string `json:"docker_host,omitempty"` ContainerName string `json:"container_name,omitempty"` @@ -58,17 +61,19 @@ func ValidateConfig(cont *docker.Container) (*Config, E.Error) { stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout) stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod) signal := E.Collect(errs, validateSignal, cont.StopSignal) + startEndpoint := E.Collect(errs, validateStartEndpoint, cont.StartEndpoint) if errs.HasError() { return nil, errs.Error() } return &Config{ - IdleTimeout: idleTimeout, - WakeTimeout: wakeTimeout, - StopTimeout: int(stopTimeout.Seconds()), - StopMethod: stopMethod, - StopSignal: signal, + IdleTimeout: idleTimeout, + WakeTimeout: wakeTimeout, + StopTimeout: int(stopTimeout.Seconds()), + StopMethod: stopMethod, + StopSignal: signal, + StartEndpoint: startEndpoint, DockerHost: cont.DockerHost, ContainerName: cont.ContainerName, @@ -104,3 +109,21 @@ func validateStopMethod(s string) (StopMethod, error) { return "", errors.New("invalid stop method " + s) } } + +func validateStartEndpoint(s string) (string, error) { + if s == "" { + return "", nil + } + // checks needed as of Go 1.6 because of change https://github.com/golang/go/commit/617c93ce740c3c3cc28cdd1a0d712be183d0b328#diff-6c2d018290e298803c0c9419d8739885L195 + // emulate browser and strip the '#' suffix prior to validation. see issue-#237 + if i := strings.Index(s, "#"); i > -1 { + s = s[:i] + } + if len(s) == 0 { + return "", errors.New("start endpoint must not be empty if defined") + } + if _, err := url.ParseRequestURI(s); err != nil { + return "", err + } + return s, nil +} diff --git a/internal/docker/idlewatcher/types/config_test.go b/internal/docker/idlewatcher/types/config_test.go new file mode 100644 index 0000000..730c0c7 --- /dev/null +++ b/internal/docker/idlewatcher/types/config_test.go @@ -0,0 +1,47 @@ +package types + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestValidateStartEndpoint(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "valid", + input: "/start", + wantErr: false, + }, + { + name: "invalid", + input: "../foo", + wantErr: true, + }, + { + name: "single fragment", + input: "#", + wantErr: true, + }, + { + name: "empty", + input: "", + wantErr: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, err := validateStartEndpoint(tc.input) + if err == nil { + ExpectEqual(t, s, tc.input) + } + if (err != nil) != tc.wantErr { + t.Errorf("validateStartEndpoint() error = %v, wantErr %t", err, tc.wantErr) + } + }) + } +} diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 6439e65..c2c34c8 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -8,7 +8,7 @@ import ( "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/metrics" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" net "github.com/yusing/go-proxy/internal/net/types" route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" @@ -22,7 +22,7 @@ type ( waker struct { _ U.NoCopy - rp *gphttp.ReverseProxy + rp *reverseproxy.ReverseProxy stream net.Stream hc health.HealthChecker metric *metrics.Gauge @@ -38,7 +38,7 @@ const ( // TODO: support stream -func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { +func newWaker(parent task.Parent, entry route.Entry, rp *reverseproxy.ReverseProxy, stream net.Stream) (Waker, E.Error) { hcCfg := entry.RawEntry().HealthCheck hcCfg.Timeout = idleWakerCheckTimeout @@ -71,7 +71,7 @@ func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, st } // lifetime should follow route provider. -func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { +func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *reverseproxy.ReverseProxy) (Waker, E.Error) { return newWaker(parent, entry, rp, nil) } diff --git a/internal/docker/idlewatcher/waker_http.go b/internal/docker/idlewatcher/waker_http.go index 9f8085d..cb95f10 100644 --- a/internal/docker/idlewatcher/waker_http.go +++ b/internal/docker/idlewatcher/waker_http.go @@ -34,6 +34,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN return true } + // Check if start endpoint is configured and request path matches + if w.StartEndpoint != "" && r.URL.Path != w.StartEndpoint { + http.Error(rw, "Forbidden: Container can only be started via configured start endpoint", http.StatusForbidden) + return false + } + if r.Body != nil { defer r.Body.Close() } diff --git a/internal/docker/labels.go b/internal/docker/labels.go index 8c78e79..0a9e0a5 100644 --- a/internal/docker/labels.go +++ b/internal/docker/labels.go @@ -5,11 +5,12 @@ const ( NSProxy = "proxy" - LabelAliases = NSProxy + ".aliases" - LabelExclude = NSProxy + ".exclude" - LabelIdleTimeout = NSProxy + ".idle_timeout" - LabelWakeTimeout = NSProxy + ".wake_timeout" - LabelStopMethod = NSProxy + ".stop_method" - LabelStopTimeout = NSProxy + ".stop_timeout" - LabelStopSignal = NSProxy + ".stop_signal" + LabelAliases = NSProxy + ".aliases" + LabelExclude = NSProxy + ".exclude" + LabelIdleTimeout = NSProxy + ".idle_timeout" + LabelWakeTimeout = NSProxy + ".wake_timeout" + LabelStopMethod = NSProxy + ".stop_method" + LabelStopTimeout = NSProxy + ".stop_timeout" + LabelStopSignal = NSProxy + ".stop_signal" + LabelStartEndpoint = NSProxy + ".start_endpoint" ) diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 0046a93..a2f5c0a 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "strings" - "sync" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/accesslog" @@ -17,32 +16,31 @@ import ( "github.com/yusing/go-proxy/internal/utils/strutils" ) -var findRouteFunc = findRouteAnyDomain - -var ( - epMiddleware *middleware.Middleware - epMiddlewareMu sync.Mutex - - epAccessLogger *accesslog.AccessLogger - epAccessLoggerMu sync.Mutex -) +type Entrypoint struct { + middleware *middleware.Middleware + accessLogger *accesslog.AccessLogger + findRouteFunc func(host string) (route.HTTPRoute, error) +} var ErrNoSuchRoute = errors.New("no such route") -func SetFindRouteDomains(domains []string) { - if len(domains) == 0 { - findRouteFunc = findRouteAnyDomain - } else { - findRouteFunc = findRouteByDomains(domains) +func NewEntrypoint() *Entrypoint { + return &Entrypoint{ + findRouteFunc: findRouteAnyDomain, } } -func SetMiddlewares(mws []map[string]any) error { - epMiddlewareMu.Lock() - defer epMiddlewareMu.Unlock() +func (ep *Entrypoint) SetFindRouteDomains(domains []string) { + if len(domains) == 0 { + ep.findRouteFunc = findRouteAnyDomain + } else { + ep.findRouteFunc = findRouteByDomains(domains) + } +} +func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { if len(mws) == 0 { - epMiddleware = nil + ep.middleware = nil return nil } @@ -50,22 +48,19 @@ func SetMiddlewares(mws []map[string]any) error { if err != nil { return err } - epMiddleware = mid + ep.middleware = mid logger.Debug().Msg("entrypoint middleware loaded") return nil } -func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { - epAccessLoggerMu.Lock() - defer epAccessLoggerMu.Unlock() - +func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { if cfg == nil { - epAccessLogger = nil + ep.accessLogger = nil return } - epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) + ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) if err != nil { return } @@ -73,28 +68,18 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { return } -func Handler(w http.ResponseWriter, r *http.Request) { - mux, err := findRouteFunc(r.Host) +func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + mux, err := ep.findRouteFunc(r.Host) if err == nil { - if epAccessLogger != nil { - epMiddlewareMu.Lock() - if epAccessLogger != nil { - w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { - epAccessLogger.Log(r, resp) - return nil - }) - } - epMiddlewareMu.Unlock() + if ep.accessLogger != nil { + w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { + ep.accessLogger.Log(r, resp) + return nil + }) } - if epMiddleware != nil { - epMiddlewareMu.Lock() - if epMiddleware != nil { - mid := epMiddleware - epMiddlewareMu.Unlock() - mid.ServeHTTP(mux.ServeHTTP, w, r) - return - } - epMiddlewareMu.Unlock() + if ep.middleware != nil { + ep.middleware.ServeHTTP(mux.ServeHTTP, w, r) + return } mux.ServeHTTP(w, r) return diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go index 65153a4..438bc1f 100644 --- a/internal/entrypoint/entrypoint_test.go +++ b/internal/entrypoint/entrypoint_test.go @@ -8,18 +8,19 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -var r route.HTTPRoute +var ( + r route.HTTPRoute + ep = NewEntrypoint() +) func run(t *testing.T, match []string, noMatch []string) { t.Helper() t.Cleanup(routes.TestClear) - t.Cleanup(func() { - SetFindRouteDomains(nil) - }) + t.Cleanup(func() { ep.SetFindRouteDomains(nil) }) for _, test := range match { t.Run(test, func(t *testing.T) { - found, err := findRouteFunc(test) + found, err := ep.findRouteFunc(test) ExpectNoError(t, err) ExpectTrue(t, found == &r) }) @@ -27,7 +28,7 @@ func run(t *testing.T, match []string, noMatch []string) { for _, test := range noMatch { t.Run(test, func(t *testing.T) { - _, err := findRouteFunc(test) + _, err := ep.findRouteFunc(test) ExpectError(t, ErrNoSuchRoute, err) }) } @@ -72,7 +73,7 @@ func TestFindRouteExactHostMatch(t *testing.T) { } func TestFindRouteByDomains(t *testing.T) { - SetFindRouteDomains([]string{ + ep.SetFindRouteDomains([]string{ ".domain.com", ".sub.domain.com", }) @@ -97,7 +98,7 @@ func TestFindRouteByDomains(t *testing.T) { } func TestFindRouteByDomainsExactMatch(t *testing.T) { - SetFindRouteDomains([]string{ + ep.SetFindRouteDomains([]string{ ".domain.com", ".sub.domain.com", }) diff --git a/internal/error/subject.go b/internal/error/subject.go index b649f0e..f46727d 100644 --- a/internal/error/subject.go +++ b/internal/error/subject.go @@ -8,8 +8,8 @@ import ( //nolint:errname type withSubject struct { - Subject string `json:"subject"` - Err error `json:"err"` + Subjects []string `json:"subjects"` + Err error `json:"err"` } const subjectSep = " > " @@ -30,13 +30,18 @@ func PrependSubject(subject string, err error) error { case Error: return err.Subject(subject) } - return &withSubject{subject, err} + return &withSubject{[]string{subject}, err} } func (err *withSubject) Prepend(subject string) *withSubject { clone := *err if subject != "" { - clone.Subject = subject + subjectSep + clone.Subject + switch subject[0] { + case '[', '(', '{': + clone.Subjects[len(clone.Subjects)-1] += subject + default: + clone.Subjects = append(clone.Subjects, subject) + } } return &clone } @@ -50,7 +55,22 @@ func (err *withSubject) Unwrap() error { } func (err *withSubject) Error() string { - subjects := strings.Split(err.Subject, subjectSep) - subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1]) - return strings.Join(subjects, subjectSep) + ": " + err.Err.Error() + // subject is in reversed order + n := len(err.Subjects) + size := 0 + errStr := err.Err.Error() + var sb strings.Builder + for _, s := range err.Subjects { + size += len(s) + } + sb.Grow(size + 2 + n*len(subjectSep) + len(errStr)) + + for i := n - 1; i > 0; i-- { + sb.WriteString(err.Subjects[i]) + sb.WriteString(subjectSep) + } + sb.WriteString(highlight(err.Subjects[0])) + sb.WriteString(": ") + sb.WriteString(errStr) + return sb.String() } diff --git a/internal/net/http/accesslog/access_logger.go b/internal/net/http/accesslog/access_logger.go index 7738287..6d493f8 100644 --- a/internal/net/http/accesslog/access_logger.go +++ b/internal/net/http/accesslog/access_logger.go @@ -129,7 +129,6 @@ func (l *AccessLogger) Flush(force bool) { l.write(l.buf.Bytes()) l.buf.Reset() l.bufMu.Unlock() - logger.Debug().Msg("access log flushed to " + l.io.Name()) } } @@ -170,5 +169,7 @@ func (l *AccessLogger) write(data []byte) { l.io.Unlock() if err != nil { l.handleErr(err) + } else { + logger.Debug().Msg("access log flushed to " + l.io.Name()) } } diff --git a/internal/net/http/accesslog/file_logger.go b/internal/net/http/accesslog/file_logger.go index 7b3aec7..e2ba4bc 100644 --- a/internal/net/http/accesslog/file_logger.go +++ b/internal/net/http/accesslog/file_logger.go @@ -3,36 +3,66 @@ package accesslog import ( "fmt" "os" + "path" "sync" "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" ) type File struct { *os.File sync.Mutex + + // os.File.Name() may not equal to key of `openedFiles`. + // Store it for later delete from `openedFiles`. + path string + + refCount *utils.RefCount } var ( - openedFiles = make(map[string]AccessLogIO) + openedFiles = make(map[string]*File) openedFilesMu sync.Mutex ) func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { openedFilesMu.Lock() - var io AccessLogIO - if opened, ok := openedFiles[cfg.Path]; ok { - io = opened + var file *File + path := path.Clean(cfg.Path) + if opened, ok := openedFiles[path]; ok { + opened.refCount.Add() + file = opened } else { - f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0o644) if err != nil { + openedFilesMu.Unlock() return nil, fmt.Errorf("access log open error: %w", err) } - io = &File{File: f} - openedFiles[cfg.Path] = io + file = &File{File: f, path: path, refCount: utils.NewRefCounter()} + openedFiles[path] = file + go file.closeOnZero() } openedFilesMu.Unlock() - return NewAccessLogger(parent, io, cfg), nil + return NewAccessLogger(parent, file, cfg), nil +} + +func (f *File) Close() error { + f.refCount.Sub() + return nil +} + +func (f *File) closeOnZero() { + defer logger.Debug(). + Str("path", f.path). + Msg("access log closed") + + <-f.refCount.Zero() + + openedFilesMu.Lock() + delete(openedFiles, f.path) + openedFilesMu.Unlock() + f.File.Close() } diff --git a/internal/net/http/header_utils.go b/internal/net/http/header_utils.go index f086c07..db8c78f 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/http/header_utils.go @@ -2,6 +2,10 @@ package http import ( "net/http" + "net/textproto" + + "github.com/yusing/go-proxy/internal/utils/strutils" + "golang.org/x/net/http/httpguts" ) const ( @@ -22,6 +26,48 @@ const ( HeaderContentLength = "Content-Length" ) +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func UpgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +// RemoveHopByHopHeaders removes hop-by-hop headers. +func RemoveHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strutils.SplitComma(f) { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + func RemoveHop(h http.Header) { reqUpType := UpgradeType(h) RemoveHopByHopHeaders(h) diff --git a/internal/net/http/methods.go b/internal/net/http/methods.go new file mode 100644 index 0000000..a46923d --- /dev/null +++ b/internal/net/http/methods.go @@ -0,0 +1,20 @@ +package http + +import "net/http" + +var validMethods = map[string]struct{}{ + http.MethodGet: {}, + http.MethodHead: {}, + http.MethodPost: {}, + http.MethodPut: {}, + http.MethodPatch: {}, + http.MethodDelete: {}, + http.MethodConnect: {}, + http.MethodOptions: {}, + http.MethodTrace: {}, +} + +func IsMethodValid(method string) bool { + _, ok := validMethods[method] + return ok +} diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index d6726a6..e123c86 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -4,7 +4,10 @@ import ( "net" "net/http" + "github.com/go-playground/validator/v10" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -16,7 +19,7 @@ type ( } CIDRWhitelistOpts struct { Allow []*types.CIDR `validate:"min=1"` - StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` + StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"` Message string } ) @@ -30,6 +33,13 @@ var ( } ) +func init() { + utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool { + statusCode := fl.Field().Int() + return gphttp.IsStatusCodeValid(int(statusCode)) + }) +} + // setup implements MiddlewareWithSetup. func (wl *cidrWhitelist) setup() { wl.CIDRWhitelistOpts = cidrWhitelistDefaults diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index b9bd3a1..64fc9e8 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -24,6 +24,18 @@ func TestCIDRWhitelistValidation(t *testing.T) { "message": testMessage, }) ExpectNoError(t, err) + _, err = CIDRWhiteList.New(OptionsRaw{ + "allow": []string{"192.168.2.100/32"}, + "message": testMessage, + "status": 403, + }) + ExpectNoError(t, err) + _, err = CIDRWhiteList.New(OptionsRaw{ + "allow": []string{"192.168.2.100/32"}, + "message": testMessage, + "status_code": 403, + }) + ExpectNoError(t, err) }) t.Run("missing allow", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index ff854e9..78410d3 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -9,14 +9,15 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/utils" ) type ( Error = E.Error - ReverseProxy = gphttp.ReverseProxy - ProxyRequest = gphttp.ProxyRequest + ReverseProxy = reverseproxy.ReverseProxy + ProxyRequest = reverseproxy.ProxyRequest ImplNewFunc = func() any OptionsRaw = map[string]any @@ -93,9 +94,9 @@ func (m *Middleware) finalize() { } func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { - if m.construct == nil { - if optsRaw != nil { - panic("bug: middleware already constructed") + if m.construct == nil { // likely a middleware from compose + if len(optsRaw) != 0 { + return nil, E.New("additional options not allowed for middleware ").Subject(m.name) } return m, nil } diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 74ccb45..cfea2c3 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -61,17 +61,38 @@ func LoadComposeFiles() { logger.Err(err).Msg("failed to list middleware definitions") return } + for _, defFile := range middlewareDefs { + voidErrs := E.NewBuilder("") // ignore these errors, will be added in next step + mws := BuildMiddlewaresFromComposeFile(defFile, voidErrs) + if len(mws) == 0 { + continue + } + for name, m := range mws { + name = strutils.ToLowerNoSnake(name) + if _, ok := allMiddlewares[name]; ok { + errs.Add(ErrDuplicatedMiddleware.Subject(name)) + continue + } + allMiddlewares[name] = m + logger.Info(). + Str("src", path.Base(defFile)). + Str("name", name). + Msg("middleware loaded") + } + } + // build again to resolve cross references for _, defFile := range middlewareDefs { mws := BuildMiddlewaresFromComposeFile(defFile, errs) if len(mws) == 0 { continue } for name, m := range mws { + name = strutils.ToLowerNoSnake(name) if _, ok := allMiddlewares[name]; ok { - errs.Add(ErrDuplicatedMiddleware.Subject(name)) + // already loaded above continue } - allMiddlewares[strutils.ToLowerNoSnake(name)] = m + allMiddlewares[name] = m logger.Info(). Str("src", path.Base(defFile)). Str("name", name). diff --git a/internal/net/http/middleware/set_upstream_headers.go b/internal/net/http/middleware/set_upstream_headers.go index e963cbf..009fc84 100644 --- a/internal/net/http/middleware/set_upstream_headers.go +++ b/internal/net/http/middleware/set_upstream_headers.go @@ -4,6 +4,7 @@ import ( "net/http" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" ) // internal use only. @@ -13,7 +14,7 @@ type setUpstreamHeaders struct { var suh = NewMiddleware[setUpstreamHeaders]() -func newSetUpstreamHeaders(rp *gphttp.ReverseProxy) *Middleware { +func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware { m, err := suh.New(OptionsRaw{ "name": rp.TargetName, "scheme": rp.TargetURL.Scheme, diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index eb6fdf3..dceeb39 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -10,7 +10,7 @@ import ( "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/types" ) @@ -139,7 +139,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E rr.parent = http.DefaultTransport } - rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr) + rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr) mid, setOptErr := middleware.New(args.middlewareOpt) if setOptErr != nil { diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/http/reverseproxy/reverse_proxy_mod.go new file mode 100644 index 0000000..0bb0d4b --- /dev/null +++ b/internal/net/http/reverseproxy/reverse_proxy_mod.go @@ -0,0 +1,577 @@ +// Copyright 2011 The Go Authors. +// Modified from the Go project under the a BSD-style License (https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/net/http/httputil/reverseproxy.go) +// https://cs.opensource.google/go/go/+/master:LICENSE + +package reverseproxy + +// This is a small mod on net/http/httputil/reverseproxy.go +// that boosts performance in some cases +// and compatible to other modules of this project +// Copyright (c) 2024 yusing + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/metrics" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/net/types" + U "github.com/yusing/go-proxy/internal/utils" + "golang.org/x/net/http/httpguts" +) + +// A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. +type ProxyRequest struct { + // In is the request received by the proxy. + // The Rewrite function must not modify In. + In *http.Request + + // Out is the request which will be sent by the proxy. + // The Rewrite function may modify or replace this request. + // Hop-by-hop headers are removed from this request + // before Rewrite is called. + Out *http.Request +} + +// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and +// X-Forwarded-Proto headers of the outbound request. +// +// - The X-Forwarded-For header is set to the client IP address. +// - The X-Forwarded-Host header is set to the host name requested +// by the client. +// - The X-Forwarded-Proto header is set to "http" or "https", depending +// on whether the inbound request was made on a TLS-enabled connection. +// +// If the outbound request contains an existing X-Forwarded-For header, +// SetXForwarded appends the client IP address to it. To append to the +// inbound request's X-Forwarded-For header (the default behavior of +// [ReverseProxy] when using a Director function), copy the header +// from the inbound request before calling SetXForwarded: +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] +// r.SetXForwarded() +// } + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. +type ReverseProxy struct { + zerolog.Logger + + // The transport used to perform proxy requests. + Transport http.RoundTripper + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called before ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + AccessLogger *accesslog.AccessLogger + + HandlerFunc http.HandlerFunc + + TargetName string + TargetURL types.URL +} + +type httpMetricLogger struct { + http.ResponseWriter + timestamp time.Time + labels *metrics.HTTPRouteMetricLabels +} + +var logger = logging.With().Str("module", "reverse_proxy").Logger() + +// WriteHeader implements http.ResponseWriter. +func (l *httpMetricLogger) WriteHeader(status int) { + l.ResponseWriter.WriteHeader(status) + duration := time.Since(l.timestamp) + go func() { + m := metrics.GetRouteMetrics() + m.HTTPReqTotal.Inc() + m.HTTPReqElapsed.With(l.labels).Set(float64(duration.Milliseconds())) + + // ignore 1xx + switch { + case status >= 500: + m.HTTP5xx.With(l.labels).Inc() + case status >= 400: + m.HTTP4xx.With(l.labels).Inc() + case status >= 200: + m.HTTP2xx3xx.With(l.labels).Inc() + } + }() +} + +func (l *httpMetricLogger) Unwrap() http.ResponseWriter { + return l.ResponseWriter +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +// NewReverseProxy returns a new [ReverseProxy] that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy { + if transport == nil { + panic("nil transport") + } + rp := &ReverseProxy{ + Logger: logger.With().Str("name", name).Logger(), + Transport: transport, + TargetName: name, + TargetURL: target, + } + rp.HandlerFunc = rp.handler + return rp +} + +func (p *ReverseProxy) UnregisterMetrics() { + metrics.GetRouteMetrics().UnregisterService(p.TargetName) +} + +func (p *ReverseProxy) rewriteRequestURL(req *http.Request) { + targetQuery := p.TargetURL.RawQuery + req.URL.Scheme = p.TargetURL.Scheme + req.URL.Host = p.TargetURL.Host + req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) { + switch { + case errors.Is(err, context.Canceled), + errors.Is(err, io.EOF): + logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error") + default: + logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error") + } + if writeHeader { + rw.WriteHeader(http.StatusInternalServerError) + } + if p.AccessLogger != nil { + p.AccessLogger.LogError(r, err) + } +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, origReq, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + res.Request = origReq + err := p.ModifyResponse(res) + res.Request = req + if err != nil { + res.Body.Close() + p.errorHandler(rw, req, err, true) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.HandlerFunc(rw, req) +} + +func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { + visitorIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + visitorIP = req.RemoteAddr + } + + if common.PrometheusEnabled { + t := time.Now() + // req.RemoteAddr had been modified by middleware (if any) + lbls := &metrics.HTTPRouteMetricLabels{ + Service: p.TargetName, + Method: req.Method, + Host: req.Host, + Visitor: visitorIP, + Path: req.URL.Path, + } + rw = &httpMetricLogger{ + ResponseWriter: rw, + timestamp: t, + labels: lbls, + } + } + + transport := p.Transport + + ctx := req.Context() + /* trunk-ignore(golangci-lint/revive) */ + if ctx.Done() != nil { + // CloseNotifier predates context.Context, and has been + // entirely superseded by it. If the request contains + // a Context that carries a cancellation signal, don't + // bother spinning up a goroutine to watch the CloseNotify + // channel (if any). + // + // If the request Context has a nil Done channel (which + // means it is either context.Background, or a custom + // Context implementation with no cancellation signal), + // then consult the CloseNotifier if available. + } else if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + p.rewriteRequestURL(outreq) + outreq.Close = false + + reqUpType := gphttp.UpgradeType(outreq.Header) + if !IsPrint(reqUpType) { + p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true) + return + } + + req.Header.Del("Forwarded") + gphttp.RemoveHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + + if strings.EqualFold(reqUpType, "websocket") { + cleanWebsocketHeaders(outreq) + } + } + + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header[gphttp.HeaderXForwardedFor] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + xff := visitorIP + if len(prior) > 0 { + xff = strings.Join(prior, ", ") + ", " + xff + } + if !omit { + outreq.Header.Set(gphttp.HeaderXForwardedFor, xff) + } + + var reqScheme string + if req.TLS != nil { + reqScheme = "https" + } else { + reqScheme = "http" + } + + outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) + outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme) + outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) + outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI) + + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + var ( + roundTripMutex sync.Mutex + roundTripDone bool + ) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMutex.Lock() + defer roundTripMutex.Unlock() + if roundTripDone { + // If RoundTrip has returned, don't try to further modify + // the ResponseWriter's header map. + return nil + } + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + clear(h) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + + roundTripMutex.Lock() + roundTripDone = true + roundTripMutex.Unlock() + if err != nil { + p.errorHandler(rw, outreq, err, false) + res = &http.Response{ + Status: http.StatusText(http.StatusBadGateway), + StatusCode: http.StatusBadGateway, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))), + Request: req, + TLS: req.TLS, + } + } + + if p.AccessLogger != nil { + defer func() { + p.AccessLogger.Log(req, res) + }() + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, req, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + gphttp.RemoveHopByHopHeaders(res.Header) + + if !p.modifyResponse(rw, res, req, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + _, err = io.Copy(rw, res.Body) + if err != nil { + if !errors.Is(err, context.Canceled) { + p.errorHandler(rw, req, err, true) + } + res.Body.Close() + return + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + http.NewResponseController(rw).Flush() + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +// reference: https://github.com/traefik/traefik/blob/master/pkg/proxy/httputil/proxy.go +// https://tools.ietf.org/html/rfc6455#page-20 +func cleanWebsocketHeaders(req *http.Request) { + req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"] + delete(req.Header, "Sec-Websocket-Key") + + req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"] + delete(req.Header, "Sec-Websocket-Extensions") + + req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"] + delete(req.Header, "Sec-Websocket-Accept") + + req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"] + delete(req.Header, "Sec-Websocket-Protocol") + + req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"] + delete(req.Header, "Sec-Websocket-Version") +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := gphttp.UpgradeType(req.Header) + resUpType := gphttp.UpgradeType(res.Header) + if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. + p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) + return + } + if !strings.EqualFold(reqUpType, resUpType) { + p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true) + return + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true) + return + } + + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.errorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw), true) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + + if hijackErr != nil { + p.errorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", hijackErr), true) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + /* trunk-ignore(golangci-lint/errorlint) */ + p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true) + return + } + if err := brw.Flush(); err != nil { + /* trunk-ignore(golangci-lint/errorlint) */ + p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true) + return + } + + bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn) + /* trunk-ignore(golangci-lint/errcheck) */ + bdp.Start() +} + +func IsPrint(s string) bool { + for _, r := range s { + if r < ' ' || r > '~' { + return false + } + } + return true +} diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index fde17fd..ce30d01 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -35,9 +35,9 @@ type Options struct { Handler http.Handler } -func StartServer(opt Options) (s *Server) { +func StartServer(parent task.Parent, opt Options) (s *Server) { s = NewServer(opt) - s.Start() + s.Start(parent) return s } @@ -83,11 +83,13 @@ func NewServer(opt Options) (s *Server) { // If both are not set, this does nothing. // // Start() is non-blocking. -func (s *Server) Start() { +func (s *Server) Start(parent task.Parent) { if s.http == nil && s.https == nil { return } + task := parent.Subtask("server."+s.Name, false) + s.startTime = time.Now() if s.http != nil { go func() { @@ -105,7 +107,7 @@ func (s *Server) Start() { s.l.Info().Str("addr", s.https.Addr).Msgf("server started") } - task.OnProgramExit("server."+s.Name+".stop", s.stop) + task.OnCancel("stop", s.stop) } func (s *Server) stop() { @@ -113,14 +115,19 @@ func (s *Server) stop() { return } + ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) + defer cancel() + if s.http != nil && s.httpStarted { - s.handleErr("http", s.http.Shutdown(task.RootContext())) + s.handleErr("http", s.http.Shutdown(ctx)) s.httpStarted = false + s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped") } if s.https != nil && s.httpsStarted { - s.handleErr("https", s.https.Shutdown(task.RootContext())) + s.handleErr("https", s.https.Shutdown(ctx)) s.httpsStarted = false + s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped") } } diff --git a/internal/net/http/status_code.go b/internal/net/http/status_code.go index db8002c..8235805 100644 --- a/internal/net/http/status_code.go +++ b/internal/net/http/status_code.go @@ -5,3 +5,7 @@ import "net/http" func IsSuccess(status int) bool { return status >= http.StatusOK && status < http.StatusMultipleChoices } + +func IsStatusCodeValid(status int) bool { + return http.StatusText(status) != "" +} diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 1aa00b9..67ca297 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -8,6 +8,11 @@ import ( //nolint:recvcheck type CIDR net.IPNet +func ParseCIDR(v string) (cidr CIDR, err error) { + err = cidr.Parse(v) + return +} + func (cidr *CIDR) Parse(v string) error { if !strings.Contains(v, "/") { v += "/32" // single IP diff --git a/internal/notif/webhook.go b/internal/notif/webhook.go index 74e8293..46d5f23 100644 --- a/internal/notif/webhook.go +++ b/internal/notif/webhook.go @@ -49,10 +49,7 @@ func jsonIfTemplateNotUsed(fl validator.FieldLevel) bool { func init() { utils.RegisterDefaultValueFactory(DefaultValue) - err := utils.Validator().RegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed) - if err != nil { - panic(err) - } + utils.MustRegisterValidation("jsonIfTemplateNotUsed", jsonIfTemplateNotUsed) } // Name implements Provider. diff --git a/internal/route/http.go b/internal/route/http.go index eb60cf2..e5ce760 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -13,6 +13,7 @@ import ( "github.com/yusing/go-proxy/internal/net/http/loadbalancer" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/route/entry" "github.com/yusing/go-proxy/internal/route/routes" route "github.com/yusing/go-proxy/internal/route/types" @@ -30,7 +31,7 @@ type ( loadBalancer *loadbalancer.LoadBalancer server *loadbalancer.Server handler http.Handler - rp *gphttp.ReverseProxy + rp *reverseproxy.ReverseProxy task *task.Task @@ -49,7 +50,7 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) { } service := entry.TargetName() - rp := gphttp.NewReverseProxy(service, entry.URL, trans) + rp := reverseproxy.NewReverseProxy(service, entry.URL, trans) if len(entry.Raw.Middlewares) > 0 { err := middleware.PatchReverseProxy(rp, entry.Raw.Middlewares) @@ -138,7 +139,7 @@ func (r *HTTPRoute) Start(parent task.Parent) E.Error { } if len(r.Raw.Rules) > 0 { - r.handler = r.Raw.Rules.BuildHandler(r.rp) + r.handler = r.Raw.Rules.BuildHandler(r.handler) } if r.HealthMon != nil { diff --git a/internal/route/provider/docker_labels.yaml b/internal/route/provider/docker_labels.yaml index 8065654..8ab6c21 100644 --- a/internal/route/provider/docker_labels.yaml +++ b/internal/route/provider/docker_labels.yaml @@ -72,9 +72,9 @@ proxy.app1.host: 10.0.0.254 proxy.app1.port: 80 proxy.app1.path_patterns: | # Check https://pkg.go.dev/net/http#hdr-Patterns-ServeMux for syntax - GET / # accept any GET request - POST /auth # for /auth and /auth/* accept only POST - GET /home/{$} # for exactly /home + - GET / # accept any GET request + - POST /auth # for /auth and /auth/* accept only POST + - GET /home/{$} # for exactly /home proxy.app1.healthcheck.disabled: false proxy.app1.healthcheck.path: / proxy.app1.healthcheck.interval: 5s diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 46a95b6..45eb1eb 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -5,6 +5,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/entry" + "github.com/yusing/go-proxy/internal/route/provider/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher" ) @@ -87,10 +88,10 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { switch handler.provider.GetType() { - case ProviderTypeDocker: + case types.ProviderTypeDocker: return route.Entry.Container.ContainerID == event.ActorID || route.Entry.Container.ContainerName == event.ActorName - case ProviderTypeFile: + case types.ProviderTypeFile: return true } // should never happen diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 9933bb3..89327da 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -10,6 +10,8 @@ import ( "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/provider/types" + route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" W "github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher/events" @@ -20,7 +22,7 @@ type ( ProviderImpl `json:"-"` name string - t ProviderType + t types.ProviderType routes R.Routes watcher W.Watcher @@ -31,24 +33,20 @@ type ( NewWatcher() W.Watcher Logger() *zerolog.Logger } - ProviderType string ProviderStats struct { - NumRPs int `json:"num_reverse_proxies"` - NumStreams int `json:"num_streams"` - Type ProviderType `json:"type"` + NumRPs int `json:"num_reverse_proxies"` + NumStreams int `json:"num_streams"` + Type types.ProviderType `json:"type"` } ) const ( - ProviderTypeDocker ProviderType = "docker" - ProviderTypeFile ProviderType = "file" - providerEventFlushInterval = 300 * time.Millisecond ) var ErrEmptyProviderName = errors.New("empty provider name") -func newProvider(name string, t ProviderType) *Provider { +func newProvider(name string, t types.ProviderType) *Provider { return &Provider{ name: name, t: t, @@ -61,7 +59,7 @@ func NewFileProvider(filename string) (p *Provider, err error) { if name == "" { return nil, ErrEmptyProviderName } - p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile) + p = newProvider(strings.ReplaceAll(name, ".", "_"), types.ProviderTypeFile) p.ProviderImpl, err = FileProviderImpl(filename) if err != nil { return nil, err @@ -75,7 +73,7 @@ func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) return nil, ErrEmptyProviderName } - p = newProvider(name, ProviderTypeDocker) + p = newProvider(name, types.ProviderTypeDocker) p.ProviderImpl, err = DockerProviderImpl(name, dockerHost, p.IsExplicitOnly()) if err != nil { return nil, err @@ -92,7 +90,7 @@ func (p *Provider) GetName() string { return p.name } -func (p *Provider) GetType() ProviderType { +func (p *Provider) GetType() types.ProviderType { return p.t } @@ -111,7 +109,7 @@ func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error { return nil } -// Start implements*task.TaskStarter. +// Start implements task.TaskStarter. func (p *Provider) Start(parent task.Parent) E.Error { t := parent.Subtask("provider."+p.name, false) @@ -171,9 +169,9 @@ func (p *Provider) Statistics() ProviderStats { numStreams := 0 p.routes.RangeAll(func(_ string, r *R.Route) { switch r.Type { - case R.RouteTypeReverseProxy: + case route.RouteTypeReverseProxy: numRPs++ - case R.RouteTypeStream: + case route.RouteTypeStream: numStreams++ } }) diff --git a/internal/route/provider/types/provider_type.go b/internal/route/provider/types/provider_type.go new file mode 100644 index 0000000..2907762 --- /dev/null +++ b/internal/route/provider/types/provider_type.go @@ -0,0 +1,8 @@ +package types + +type ProviderType string + +const ( + ProviderTypeDocker ProviderType = "docker" + ProviderTypeFile ProviderType = "file" +) diff --git a/internal/route/route.go b/internal/route/route.go index 4101d9a..003b912 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -14,11 +14,10 @@ import ( ) type ( - RouteType string - Route struct { + Route struct { _ U.NoCopy impl - Type RouteType + Type types.RouteType Entry *RawEntry } Routes = F.Map[string, *Route] @@ -34,11 +33,6 @@ type ( RawEntries = types.RawEntries ) -const ( - RouteTypeStream RouteType = "stream" - RouteTypeReverseProxy RouteType = "reverse_proxy" -) - // function alias. var ( NewRoutes = F.NewMap[Routes] @@ -59,15 +53,15 @@ func NewRoute(raw *RawEntry) (*Route, E.Error) { return nil, err } - var t RouteType + var t types.RouteType var rt impl switch e := en.(type) { case *entry.StreamEntry: - t = RouteTypeStream + t = types.RouteTypeStream rt, err = NewStreamRoute(e) case *entry.ReverseProxyEntry: - t = RouteTypeReverseProxy + t = types.RouteTypeReverseProxy rt, err = NewHTTPRoute(e) default: panic("bug: should not reach here") diff --git a/internal/route/routes/query.go b/internal/route/routes/query.go new file mode 100644 index 0000000..cb87ae3 --- /dev/null +++ b/internal/route/routes/query.go @@ -0,0 +1,99 @@ +package routes + +import ( + "strings" + + "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/route/entry" + provider "github.com/yusing/go-proxy/internal/route/provider/types" + "github.com/yusing/go-proxy/internal/route/types" + route "github.com/yusing/go-proxy/internal/route/types" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func HomepageConfig(useDefaultCategories bool) homepage.Config { + hpCfg := homepage.NewHomePageConfig() + GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { + en := r.RawEntry() + item := en.Homepage + if item == nil { + item = new(homepage.Item) + item.Show = true + } + + if !item.IsEmpty() { + item.Show = true + } + + if !item.Show { + return + } + + item.Alias = alias + + if item.Name == "" { + item.Name = strutils.Title( + strings.ReplaceAll( + strings.ReplaceAll(alias, "-", " "), + "_", " ", + ), + ) + } + + if useDefaultCategories { + if en.Container != nil && item.Category == "" { + if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok { + item.Category = category + } + } + + if item.Category == "" { + if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok { + item.Category = category + } + } + } + + switch { + case entry.IsDocker(r): + if item.Category == "" { + item.Category = "Docker" + } + item.SourceType = string(provider.ProviderTypeDocker) + case entry.UseLoadBalance(r): + if item.Category == "" { + item.Category = "Load-balanced" + } + item.SourceType = "loadbalancer" + default: + if item.Category == "" { + item.Category = "Others" + } + item.SourceType = string(provider.ProviderTypeFile) + } + + item.AltURL = r.TargetURL().String() + hpCfg.Add(item) + }) + return hpCfg +} + +func RoutesByAlias(typeFilter ...route.RouteType) map[string]any { + rts := make(map[string]any) + if len(typeFilter) == 0 || typeFilter[0] == "" { + typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream} + } + for _, t := range typeFilter { + switch t { + case route.RouteTypeReverseProxy: + GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { + rts[alias] = r + }) + case route.RouteTypeStream: + GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) { + rts[alias] = r + }) + } + } + return rts +} diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go new file mode 100644 index 0000000..5d93fd5 --- /dev/null +++ b/internal/route/rules/do.go @@ -0,0 +1,248 @@ +package rules + +import ( + "net/http" + "path" + "strconv" + "strings" + + E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type ( + Command struct { + raw string + exec *CommandExecutor + } + CommandExecutor struct { + directive string + http.HandlerFunc + proceed bool + } +) + +const ( + CommandRewrite = "rewrite" + CommandServe = "serve" + CommandProxy = "proxy" + CommandRedirect = "redirect" + CommandError = "error" + CommandBypass = "bypass" +) + +var commands = map[string]struct { + help Help + validate ValidateFunc + build func(args any) *CommandExecutor +}{ + CommandRewrite: { + help: Help{ + command: CommandRewrite, + args: map[string]string{ + "from": "the path to rewrite, must start with /", + "to": "the path to rewrite to, must start with /", + }, + }, + validate: func(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + return validateURLPaths(args) + }, + build: func(args any) *CommandExecutor { + a := args.([]string) + orig, repl := a[0], a[1] + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + if len(path) > 0 && path[0] != '/' { + path = "/" + path + } + if !strings.HasPrefix(path, orig) { + return + } + path = repl + path[len(orig):] + r.URL.Path = path + r.URL.RawPath = r.URL.EscapedPath() + r.RequestURI = r.URL.RequestURI() + }, + proceed: true, + } + }, + }, + CommandServe: { + help: Help{ + command: CommandServe, + args: map[string]string{ + "root": "the file system path to serve, must be an existing directory", + }, + }, + validate: validateFSPath, + build: func(args any) *CommandExecutor { + root := args.(string) + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) + }, + proceed: false, + } + }, + }, + CommandRedirect: { + help: Help{ + command: CommandRedirect, + args: map[string]string{ + "to": "the url to redirect to, can be relative or absolute URL", + }, + }, + validate: validateURL, + build: func(args any) *CommandExecutor { + target := args.(types.URL).String() + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target, http.StatusTemporaryRedirect) + }, + proceed: false, + } + }, + }, + CommandError: { + help: Help{ + command: CommandError, + args: map[string]string{ + "code": "the http status code to return", + "text": "the error message to return", + }, + }, + validate: func(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + codeStr, text := args[0], args[1] + code, err := strconv.Atoi(codeStr) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if !gphttp.IsStatusCodeValid(code) { + return nil, ErrInvalidArguments.Subject(codeStr) + } + return []any{code, text}, nil + }, + build: func(args any) *CommandExecutor { + a := args.([]any) + code, text := a[0].(int), a[1].(string) + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, text, code) + }, + proceed: false, + } + }, + }, + CommandProxy: { + help: Help{ + command: CommandProxy, + args: map[string]string{ + "to": "the url to proxy to, must be an absolute URL", + }, + }, + validate: validateAbsoluteURL, + build: func(args any) *CommandExecutor { + target := args.(types.URL) + if target.Scheme == "" { + target.Scheme = "http" + } + rp := reverseproxy.NewReverseProxy("", target, gphttp.DefaultTransport) + return &CommandExecutor{ + HandlerFunc: rp.ServeHTTP, + proceed: false, + } + }, + }, +} + +// Parse implements strutils.Parser. +func (cmd *Command) Parse(v string) error { + cmd.raw = v + + lines := strutils.SplitLine(v) + if len(lines) == 0 { + return nil + } + + executors := make([]*CommandExecutor, 0, len(lines)) + for _, line := range lines { + if line == "" { + continue + } + + directive, args, err := parse(line) + if err != nil { + return err + } + + if directive == CommandBypass { + if len(args) != 0 { + return ErrInvalidArguments.Subject(directive) + } + return nil + } + + builder, ok := commands[directive] + if !ok { + return ErrUnknownDirective.Subject(directive) + } + validArgs, err := builder.validate(args) + if err != nil { + return err.Subject(directive).Withf("%s", builder.help.String()) + } + + exec := builder.build(validArgs) + exec.directive = directive + executors = append(executors, exec) + } + + exec, err := buildCmd(executors) + if err != nil { + return err + } + cmd.exec = exec + return nil +} + +func buildCmd(executors []*CommandExecutor) (*CommandExecutor, error) { + for i, exec := range executors { + if !exec.proceed && i != len(executors)-1 { + return nil, ErrInvalidCommandSequence. + Withf("%s cannot follow %s", exec, executors[i+1]) + } + } + return &CommandExecutor{ + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + for _, exec := range executors { + exec.HandlerFunc(w, r) + } + }, + proceed: executors[len(executors)-1].proceed, + }, nil +} + +func (cmd *Command) isBypass() bool { + return cmd.exec == nil +} + +func (cmd *Command) String() string { + return cmd.raw +} + +func (cmd *Command) MarshalJSON() ([]byte, error) { + return []byte("\"" + cmd.String() + "\""), nil +} + +func (exec *CommandExecutor) String() string { + return exec.directive +} diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go new file mode 100644 index 0000000..ad79fb5 --- /dev/null +++ b/internal/route/rules/errors.go @@ -0,0 +1,15 @@ +package rules + +import E "github.com/yusing/go-proxy/internal/error" + +var ( + ErrUnterminatedQuotes = E.New("unterminated quotes") + ErrUnsupportedEscapeChar = E.New("unsupported escape char") + ErrUnknownDirective = E.New("unknown directive") + ErrInvalidArguments = E.New("invalid arguments") + ErrInvalidOnTarget = E.New("invalid `rule.on` target") + ErrInvalidCommandSequence = E.New("invalid command sequence") + + ErrExpectOneArg = ErrInvalidArguments.Withf("expect 1 arg") + ErrExpectTwoArgs = ErrInvalidArguments.Withf("expect 2 args") +) diff --git a/internal/route/rules/help.go b/internal/route/rules/help.go new file mode 100644 index 0000000..cdff5c9 --- /dev/null +++ b/internal/route/rules/help.go @@ -0,0 +1,41 @@ +package rules + +import "strings" + +type Help struct { + command string + description string + args map[string]string // args[arg] -> description +} + +/* +Generate help string, e.g. + + rewrite + from: the path to rewrite, must start with / + to: the path to rewrite to, must start with / +*/ +func (h *Help) String() string { + var sb strings.Builder + sb.WriteString(h.command) + sb.WriteString(" ") + for arg := range h.args { + sb.WriteRune('<') + sb.WriteString(arg) + sb.WriteString("> ") + } + if h.description != "" { + sb.WriteString("\n\t") + sb.WriteString(h.description) + sb.WriteRune('\n') + } + sb.WriteRune('\n') + for arg, desc := range h.args { + sb.WriteRune('\t') + sb.WriteString(arg) + sb.WriteString(": ") + sb.WriteString(desc) + sb.WriteRune('\n') + } + return sb.String() +} diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go new file mode 100644 index 0000000..5f9fccd --- /dev/null +++ b/internal/route/rules/on.go @@ -0,0 +1,254 @@ +package rules + +import ( + "net" + "net/http" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +type ( + RuleOn struct { + raw string + check CheckFulfill + } + CheckFulfill func(r *http.Request) bool + Checkers []CheckFulfill +) + +const ( + OnHeader = "header" + OnQuery = "query" + OnCookie = "cookie" + OnForm = "form" + OnPostForm = "postform" + OnMethod = "method" + OnPath = "path" + OnRemote = "remote" +) + +var checkers = map[string]struct { + help Help + validate ValidateFunc + check func(r *http.Request, args any) bool +}{ + OnHeader: { + help: Help{ + command: OnHeader, + args: map[string]string{ + "key": "the header key", + "value": "the header value", + }, + }, + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.Header.Get(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnQuery: { + help: Help{ + command: OnQuery, + args: map[string]string{ + "key": "the query key", + "value": "the query value", + }, + }, + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.URL.Query().Get(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnCookie: { + help: Help{ + command: OnCookie, + args: map[string]string{ + "key": "the cookie key", + "value": "the cookie value", + }, + }, + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + cookies := r.CookiesNamed(args.(StrTuple).First) + for _, cookie := range cookies { + if cookie.Value == args.(StrTuple).Second { + return true + } + } + return false + }, + }, + OnForm: { + help: Help{ + command: OnForm, + args: map[string]string{ + "key": "the form key", + "value": "the form value", + }, + }, + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.FormValue(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnPostForm: { + help: Help{ + command: OnPostForm, + args: map[string]string{ + "key": "the form key", + "value": "the form value", + }, + }, + validate: toStrTuple, + check: func(r *http.Request, args any) bool { + return r.PostFormValue(args.(StrTuple).First) == args.(StrTuple).Second + }, + }, + OnMethod: { + help: Help{ + command: OnMethod, + args: map[string]string{ + "method": "the http method", + }, + }, + validate: validateMethod, + check: func(r *http.Request, method any) bool { + return r.Method == method.(string) + }, + }, + OnPath: { + help: Help{ + command: OnPath, + description: `The path can be a glob pattern, e.g.: + /path/to + /path/to/*`, + args: map[string]string{ + "path": "the request path, must start with /", + }, + }, + validate: validateURLPath, + check: func(r *http.Request, globPath any) bool { + reqPath := r.URL.Path + if len(reqPath) > 0 && reqPath[0] != '/' { + reqPath = "/" + reqPath + } + return strutils.GlobMatch(globPath.(string), reqPath) + }, + }, + OnRemote: { + help: Help{ + command: OnRemote, + args: map[string]string{ + "ip|cidr": "the remote ip or cidr", + }, + }, + validate: validateCIDR, + check: func(r *http.Request, cidr any) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return cidr.(*net.IPNet).Contains(ip) + }, + }, +} + +// Parse implements strutils.Parser. +func (on *RuleOn) Parse(v string) error { + on.raw = v + + lines := strutils.SplitLine(v) + checks := make(Checkers, 0, len(lines)) + + errs := E.NewBuilder("rule.on syntax errors") + for i, line := range lines { + if line == "" { + continue + } + parsed, err := parseOn(line) + if err != nil { + errs.Add(err.Subjectf("line %d", i+1)) + continue + } + checks = append(checks, parsed.matchOne()) + } + + on.check = checks.matchAll() + return errs.Error() +} + +func (on *RuleOn) String() string { + return on.raw +} + +func (on *RuleOn) MarshalJSON() ([]byte, error) { + return []byte("\"" + on.String() + "\""), nil +} + +func parseOn(line string) (Checkers, E.Error) { + ors := strutils.SplitRune(line, '|') + + if len(ors) > 1 { + errs := E.NewBuilder("rule.on syntax errors") + checks := make([]CheckFulfill, len(ors)) + for i, or := range ors { + curCheckers, err := parseOn(or) + if err != nil { + errs.Add(err) + continue + } + checks[i] = curCheckers[0] + } + if err := errs.Error(); err != nil { + return nil, err + } + return checks, nil + } + + subject, args, err := parse(line) + if err != nil { + return nil, err + } + + checker, ok := checkers[subject] + if !ok { + return nil, ErrInvalidOnTarget.Subject(subject) + } + + validArgs, err := checker.validate(args) + if err != nil { + return nil, err.Subject(subject).Withf("%s", checker.help.String()) + } + + return Checkers{ + func(r *http.Request) bool { + return checker.check(r, validArgs) + }, + }, nil +} + +func (checkers Checkers) matchOne() CheckFulfill { + return func(r *http.Request) bool { + for _, checker := range checkers { + if checker(r) { + return true + } + } + return false + } +} + +func (checkers Checkers) matchAll() CheckFulfill { + return func(r *http.Request) bool { + for _, checker := range checkers { + if !checker(r) { + return false + } + } + return true + } +} diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go new file mode 100644 index 0000000..ac51ebf --- /dev/null +++ b/internal/route/rules/parser.go @@ -0,0 +1,79 @@ +package rules + +import ( + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +var escapedChars = map[rune]rune{ + 'n': '\n', + 't': '\t', + 'r': '\r', + '\'': '\'', + '"': '"', + '\\': '\\', + ' ': ' ', +} + +// parse expression to subject and args +// with support for quotes and escaped chars, e.g. +// +// error 403 "Forbidden 'foo' 'bar'" +// error 403 Forbidden\ \"foo\"\ \"bar\". +func parse(v string) (subject string, args []string, err E.Error) { + v = strings.TrimSpace(v) + var buf strings.Builder + escaped := false + quotes := make([]rune, 0, 4) + flush := func() { + if subject == "" { + subject = buf.String() + } else { + args = append(args, buf.String()) + } + buf.Reset() + } + for _, r := range v { + if escaped { + if ch, ok := escapedChars[r]; ok { + buf.WriteRune(ch) + } else { + err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r) + return + } + escaped = false + continue + } + switch r { + case '\\': + escaped = true + continue + case '"', '\'': + switch { + case len(quotes) > 0 && quotes[len(quotes)-1] == r: + quotes = quotes[:len(quotes)-1] + if len(quotes) == 0 { + flush() + } else { + buf.WriteRune(r) + } + case len(quotes) == 0: + quotes = append(quotes, r) + default: + buf.WriteRune(r) + } + case ' ': + flush() + default: + buf.WriteRune(r) + } + } + + if len(quotes) > 0 { + err = ErrUnterminatedQuotes + } else { + flush() + } + return +} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go new file mode 100644 index 0000000..1908eb1 --- /dev/null +++ b/internal/route/rules/rules.go @@ -0,0 +1,103 @@ +package rules + +import ( + "net/http" +) + +type ( + /* + Example: + + proxy.app1.rules: | + - name: default + do: | + rewrite / /index.html + serve /var/www/goaccess + - name: ws + on: | + header Connection Upgrade + header Upgrade websocket + do: bypass + + proxy.app2.rules: | + - name: default + do: bypass + - name: block POST and PUT + on: method POST | method PUT + do: error 403 Forbidden + */ + Rules []Rule + /* + Rule is a rule for a reverse proxy. + It do `Do` when `On` matches. + + A rule can have multiple lines of on. + + All lines of on must match, + but each line can have multiple checks that + one match means this line is matched. + */ + Rule struct { + Name string `json:"name" validate:"required"` + On RuleOn `json:"on"` + Do Command `json:"do"` + } +) + +// BuildHandler returns a http.HandlerFunc that implements the rules. +// +// if a bypass rule matches, +// the request is passed to the upstream and no more rules are executed. +// +// if no rule matches, the default rule is executed +// if no rule matches and default rule is not set, +// the request is passed to the upstream. +func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { + var ( + defaultRule Rule + defaultRuleIndex int + ) + + for i, rule := range rules { + if rule.Name == "default" { + defaultRule = rule + defaultRuleIndex = i + break + } + } + + rules = append(rules[:defaultRuleIndex], rules[defaultRuleIndex+1:]...) + + // free allocated empty slices + // before encapsulating them into the handlerFunc. + if len(rules) == 0 { + if defaultRule.Do.isBypass() { + return up.ServeHTTP + } + rules = []Rule{} + } + + return func(w http.ResponseWriter, r *http.Request) { + hasMatch := false + for _, rule := range rules { + if rule.On.check(r) { + if rule.Do.isBypass() { + up.ServeHTTP(w, r) + return + } + rule.Do.exec.HandlerFunc(w, r) + if !rule.Do.exec.proceed { + return + } + hasMatch = true + } + } + + if hasMatch || defaultRule.Do.isBypass() { + up.ServeHTTP(w, r) + return + } + + defaultRule.Do.exec.HandlerFunc(w, r) + } +} diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go new file mode 100644 index 0000000..dd2b662 --- /dev/null +++ b/internal/route/rules/rules_test.go @@ -0,0 +1,251 @@ +package rules + +import ( + "testing" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestParseSubjectArgs(t *testing.T) { + t.Run("basic", func(t *testing.T) { + subject, args, err := parse("rewrite / /foo/bar") + ExpectNoError(t, err) + ExpectEqual(t, subject, "rewrite") + ExpectDeepEqual(t, args, []string{"/", "/foo/bar"}) + }) + t.Run("with quotes", func(t *testing.T) { + subject, args, err := parse(`error 403 "Forbidden 'foo' 'bar'."`) + ExpectNoError(t, err) + ExpectEqual(t, subject, "error") + ExpectDeepEqual(t, args, []string{"403", "Forbidden 'foo' 'bar'."}) + }) + t.Run("with escaped", func(t *testing.T) { + subject, args, err := parse(`error 403 Forbidden\ \"foo\"\ \"bar\".`) + ExpectNoError(t, err) + ExpectEqual(t, subject, "error") + ExpectDeepEqual(t, args, []string{"403", "Forbidden \"foo\" \"bar\"."}) + }) +} + +func TestParseCommands(t *testing.T) { + tests := []struct { + name string + input string + wantErr error + }{ + // bypass tests + { + name: "bypass_valid", + input: "bypass", + wantErr: nil, + }, + { + name: "bypass_invalid_with_args", + input: "bypass /", + wantErr: ErrInvalidArguments, + }, + // rewrite tests + { + name: "rewrite_valid", + input: "rewrite / /foo/bar", + wantErr: nil, + }, + { + name: "rewrite_missing_target", + input: "rewrite /", + wantErr: ErrInvalidArguments, + }, + { + name: "rewrite_too_many_args", + input: "rewrite / / /", + wantErr: ErrInvalidArguments, + }, + { + name: "rewrite_no_leading_slash", + input: "rewrite abc /", + wantErr: ErrInvalidArguments, + }, + // serve tests + { + name: "serve_valid", + input: "serve /var/www", + wantErr: nil, + }, + { + name: "serve_missing_path", + input: "serve ", + wantErr: ErrInvalidArguments, + }, + { + name: "serve_too_many_args", + input: "serve / / /", + wantErr: ErrInvalidArguments, + }, + // redirect tests + { + name: "redirect_valid", + input: "redirect /", + wantErr: nil, + }, + { + name: "redirect_too_many_args", + input: "redirect / /", + wantErr: ErrInvalidArguments, + }, + // error directive tests + { + name: "error_valid", + input: "error 404 Not\\ Found", + wantErr: nil, + }, + { + name: "error_missing_status_code", + input: "error Not\\ Found", + wantErr: ErrInvalidArguments, + }, + { + name: "error_too_many_args", + input: "error 404 Not\\ Found extra", + wantErr: ErrInvalidArguments, + }, + { + name: "error_no_escaped_space", + input: "error 404 Not Found", + wantErr: ErrInvalidArguments, + }, + { + name: "error_invalid_status_code", + input: "error 123 abc", + wantErr: ErrInvalidArguments, + }, + // proxy directive tests + { + name: "proxy_valid", + input: "proxy localhost:8080", + wantErr: nil, + }, + { + name: "proxy_missing_target", + input: "proxy", + wantErr: ErrInvalidArguments, + }, + { + name: "proxy_too_many_args", + input: "proxy localhost:8080 extra", + wantErr: ErrInvalidArguments, + }, + { + name: "proxy_invalid_url", + input: "proxy :invalid_url", + wantErr: ErrInvalidArguments, + }, + // unknown directive test + { + name: "unknown_directive", + input: "unknown /", + wantErr: ErrUnknownDirective, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := Command{} + err := cmd.Parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + } else { + ExpectNoError(t, err) + } + }) + } +} + +func TestParseOn(t *testing.T) { + tests := []struct { + name string + input string + wantErr E.Error + }{ + // header + { + name: "header_valid", + input: "header Connection Upgrade", + wantErr: nil, + }, + { + name: "header_invalid", + input: "header Connection", + wantErr: ErrInvalidArguments, + }, + // query + { + name: "query_valid", + input: "query key value", + wantErr: nil, + }, + { + name: "query_invalid", + input: "query key", + wantErr: ErrInvalidArguments, + }, + // method + { + name: "method_valid", + input: "method GET", + wantErr: nil, + }, + { + name: "method_invalid", + input: "method", + wantErr: ErrInvalidArguments, + }, + // path + { + name: "path_valid", + input: "path /home", + wantErr: nil, + }, + { + name: "path_invalid", + input: "path", + wantErr: ErrInvalidArguments, + }, + // remote + { + name: "remote_valid", + input: "remote 127.0.0.1", + wantErr: nil, + }, + { + name: "remote_invalid", + input: "remote", + wantErr: ErrInvalidArguments, + }, + { + name: "unknown_target", + input: "unknown", + wantErr: ErrInvalidOnTarget, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + on := &RuleOn{} + err := on.Parse(tt.input) + if tt.wantErr != nil { + ExpectError(t, tt.wantErr, err) + } else { + ExpectNoError(t, err) + } + }) + } +} + +func TestParseRule(t *testing.T) { + // test := map[string]any{ + // "name": "test", + // "on": "method GET", + // "do": "bypass", + // } +} diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go new file mode 100644 index 0000000..55621fb --- /dev/null +++ b/internal/route/rules/validate.go @@ -0,0 +1,125 @@ +package rules + +import ( + "os" + "path" + "strings" + + E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/types" +) + +type ( + ValidateFunc func(args []string) (any, E.Error) + StrTuple struct { + First, Second string + } +) + +func toStrTuple(args []string) (any, E.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + return StrTuple{args[0], args[1]}, nil +} + +func validateURL(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + u, err := types.ParseURL(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return u, nil +} + +func validateAbsoluteURL(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + u, err := types.ParseURL(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if u.Scheme == "" { + u.Scheme = "http" + } + if u.Host == "" { + return nil, ErrInvalidArguments.Withf("missing host") + } + return u, nil +} + +func validateCIDR(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + if !strings.Contains(args[0], "/") { + args[0] += "/32" + } + cidr, err := types.ParseCIDR(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return cidr, nil +} + +func validateURLPath(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + p := args[0] + trailingSlash := len(p) > 1 && p[len(p)-1] == '/' + p, _, _ = strings.Cut(p, "#") + p = path.Clean(p) + if len(p) == 0 { + return nil, ErrInvalidArguments.Withf("empty path") + } + if trailingSlash { + p += "/" + } + if p[0] != '/' { + return nil, ErrInvalidArguments.Withf("must start with /") + } + return p, nil +} + +func validateURLPaths(paths []string) (any, E.Error) { + errs := E.NewBuilder("invalid url paths") + for i, p := range paths { + val, err := validateURLPath([]string{p}) + if err != nil { + errs.Add(err.Subject(p)) + continue + } + paths[i] = val.(string) + } + if err := errs.Error(); err != nil { + return nil, err + } + return paths, nil +} + +func validateFSPath(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + p := path.Clean(args[0]) + if _, err := os.Stat(p); err != nil { + return nil, ErrInvalidArguments.With(err) + } + return p, nil +} + +func validateMethod(args []string) (any, E.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + method := strings.ToUpper(args[0]) + if !gphttp.IsMethodValid(method) { + return nil, ErrInvalidArguments.Subject(method) + } + return method, nil +} diff --git a/internal/route/types/raw_entry.go b/internal/route/types/raw_entry.go index 92a77b0..b3a1479 100644 --- a/internal/route/types/raw_entry.go +++ b/internal/route/types/raw_entry.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/accesslog" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/route/rules" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -30,7 +31,7 @@ type ( Port string `json:"port,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"` PathPatterns []string `json:"path_patterns,omitempty"` - Rules Rules `json:"rules,omitempty"` + Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` LoadBalance *loadbalance.Config `json:"load_balance,omitempty"` Middlewares map[string]docker.LabelMap `json:"middlewares,omitempty"` diff --git a/internal/route/types/route_type.go b/internal/route/types/route_type.go new file mode 100644 index 0000000..f5357db --- /dev/null +++ b/internal/route/types/route_type.go @@ -0,0 +1,8 @@ +package types + +type RouteType string + +const ( + RouteTypeStream RouteType = "stream" + RouteTypeReverseProxy RouteType = "reverse_proxy" +) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 8fd7654..79affb1 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -9,7 +9,6 @@ import ( "strconv" "strings" "time" - "unicode" "github.com/go-playground/validator/v10" E "github.com/yusing/go-proxy/internal/error" @@ -48,82 +47,6 @@ func New(t reflect.Type) reflect.Value { return reflect.New(t) } -// Serialize converts the given data into a map[string]any representation. -// -// It uses reflection to inspect the data type and handle different kinds of data. -// For a struct, it extracts the fields using the json tag if present, or the field name if not. -// For an embedded struct, it recursively converts its fields into the result map. -// For any other type, it returns an error. -// -// Parameters: -// - data: The data to be converted into a map. -// -// Returns: -// - result: The resulting map[string]any representation of the data. -// - error: An error if the data type is unsupported or if there is an error during conversion. -func Serialize(data any) (SerializedObject, error) { - result := make(map[string]any) - - // Use reflection to inspect the data type - value := reflect.ValueOf(data) - - // Check if the value is valid - if !value.IsValid() { - return nil, ErrInvalidType.Subjectf("%T", data) - } - - // Dereference pointers if necessary - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - - // Handle different kinds of data - switch value.Kind() { - case reflect.Map: - for _, key := range value.MapKeys() { - result[key.String()] = value.MapIndex(key).Interface() - } - case reflect.Struct: - for i := range value.NumField() { - field := value.Type().Field(i) - if !field.IsExported() { - continue - } - jsonTag := field.Tag.Get("json") // Get the json tag - if jsonTag == "-" { - continue // Ignore this field if the tag is "-" - } - if strings.Contains(jsonTag, ",omitempty") { - if value.Field(i).IsZero() { - continue - } - jsonTag = strings.Replace(jsonTag, ",omitempty", "", 1) - } - - // If the json tag is not empty, use it as the key - switch { - case jsonTag != "": - result[jsonTag] = value.Field(i).Interface() - case field.Anonymous: - // If the field is an embedded struct, add its fields to the result - fieldMap, err := Serialize(value.Field(i).Interface()) - if err != nil { - return nil, err - } - for k, v := range fieldMap { - result[k] = v - } - default: - result[field.Name] = value.Field(i).Interface() - } - } - default: - return nil, errors.New("serialize: unsupported data type " + value.Kind().String()) - } - - return result, nil -} - func extractFields(t reflect.Type) []reflect.StructField { for t.Kind() == reflect.Ptr { t = t.Elem() @@ -203,9 +126,8 @@ func Deserialize(src SerializedObject, dst any) E.Error { mapping[key] = dstV.FieldByName(field.Name) fieldName[field.Name] = key - _, ok := field.Tag.Lookup("validate") - if ok { - needValidate = true + if !needValidate { + _, needValidate = field.Tag.Lookup("validate") } aliases, ok := field.Tag.Lookup("aliases") @@ -258,7 +180,7 @@ func Deserialize(src SerializedObject, dst any) E.Error { } return errs.Error() default: - return ErrUnsupportedConversion.Subject("deserialize to " + dstT.String()) + return ErrUnsupportedConversion.Subject("mapping to " + dstT.String()) } } @@ -355,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { if dstT.Kind() != reflect.Slice { return ErrUnsupportedConversion.Subject(dstT.String() + " to " + srcT.String()) } - newSlice := reflect.MakeSlice(dstT, 0, src.Len()) + newSlice := reflect.MakeSlice(dstT, src.Len(), src.Len()) i := 0 for _, v := range src.Seq2() { tmp := New(dstT.Elem()).Elem() @@ -363,7 +285,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { if err != nil { return err.Subjectf("[%d]", i) } - newSlice = reflect.Append(newSlice, tmp) + newSlice.Index(i).Set(tmp) i++ } dst.Set(newSlice) @@ -424,10 +346,11 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E return true, E.From(parser.Parse(src)) } // yaml like - isMultiline := strings.ContainsRune(src, '\n') var tmp any switch dst.Kind() { case reflect.Slice: + src = strings.TrimSpace(src) + isMultiline := strings.ContainsRune(src, '\n') // one liner is comma separated list if !isMultiline { values := strutils.CommaSeperatedList(src) @@ -444,16 +367,10 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E } return } - lines := strutils.SplitLine(src) - sl := make([]string, 0, len(lines)) - for _, line := range lines { - line = strings.TrimLeftFunc(line, func(r rune) bool { - return r == '-' || unicode.IsSpace(r) - }) - if line == "" || line[0] == '#' { - continue - } - sl = append(sl, line) + sl := make([]any, 0) + err := yaml.Unmarshal([]byte(src), &sl) + if err != nil { + return true, E.From(err) } tmp = sl case reflect.Map, reflect.Struct: diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index abca97c..b45d95c 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -8,7 +8,7 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestSerializeDeserialize(t *testing.T) { +func TestDeserialize(t *testing.T) { type S struct { I int S string @@ -37,12 +37,6 @@ func TestSerializeDeserialize(t *testing.T) { } ) - t.Run("serialize", func(t *testing.T) { - s, err := Serialize(testStruct) - ExpectNoError(t, err) - ExpectDeepEqual(t, s, testStructSerialized) - }) - t.Run("deserialize", func(t *testing.T) { var s2 S err := Deserialize(testStructSerialized, &s2) @@ -174,7 +168,7 @@ func TestStringToSlice(t *testing.T) { }) t.Run("multiline", func(t *testing.T) { dst := make([]string, 0) - convertible, err := ConvertString(" a\n b\n c", reflect.ValueOf(&dst)) + convertible, err := ConvertString("- a\n- b\n- c", reflect.ValueOf(&dst)) ExpectTrue(t, convertible) ExpectNoError(t, err) ExpectDeepEqual(t, dst, []string{"a", "b", "c"}) diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go index 03c47c8..4664c2b 100644 --- a/internal/utils/strutils/string.go +++ b/internal/utils/strutils/string.go @@ -1,7 +1,6 @@ package strutils import ( - "net/url" "strings" "golang.org/x/text/cases" @@ -22,14 +21,6 @@ func Title(s string) string { return cases.Title(language.AmericanEnglish).String(s) } -func ExtractPort(fullURL string) (int, error) { - url, err := url.Parse(fullURL) - if err != nil { - return 0, err - } - return Atoi(url.Port()) -} - func ToLowerNoSnake(s string) string { return strings.ToLower(strings.ReplaceAll(s, "_", "")) } diff --git a/internal/utils/validation.go b/internal/utils/validation.go index f27c44c..226657a 100644 --- a/internal/utils/validation.go +++ b/internal/utils/validation.go @@ -12,3 +12,10 @@ var ErrValidationError = E.New("validation error") func Validator() *validator.Validate { return validate } + +func MustRegisterValidation(tag string, fn validator.Func) { + err := validate.RegisterValidation(tag, fn) + if err != nil { + panic(err) + } +} diff --git a/next-release.md b/next-release.md new file mode 100644 index 0000000..bf902e9 --- /dev/null +++ b/next-release.md @@ -0,0 +1,118 @@ +GoDoxy v0.8.2 expected changes + +- **Thanks [polds](https://github.com/polds)** + Optionally allow a user to specify a “warm-up” endpoint to start the container, returning a 403 if the endpoint isn’t hit and the container has been stopped. + + This can help prevent bots from starting random containers, or allow health check systems to run some probes. Or potentially lock the start endpoints behind a different authentication mechanism, etc. + + Sample service showing this: + + ```yaml + hello-world: + image: nginxdemos/hello + container_name: hello-world + restart: "no" + ports: + - "9100:80" + labels: + proxy.aliases: hello-world + proxy.#1.port: 9100 + proxy.idle_timeout: 45s + proxy.wake_timeout: 30s + proxy.stop_method: stop + proxy.stop_timeout: 10s + proxy.stop_signal: SIGTERM + proxy.start_endpoint: "/start" + ``` + + Hitting `/` on this service when the container is down: + + ```curl + $ curl -sv -X GET -H "Host: hello-world.godoxy.local" http://localhost/ + * Host localhost:80 was resolved. + * IPv6: ::1 + * IPv4: 127.0.0.1 + * Trying [::1]:80... + * Connected to localhost (::1) port 80 + > GET / HTTP/1.1 + > Host: hello-world.godoxy.local + > User-Agent: curl/8.7.1 + > Accept: */* + > + * Request completely sent off + < HTTP/1.1 403 Forbidden + < Content-Type: text/plain; charset=utf-8 + < X-Content-Type-Options: nosniff + < Date: Wed, 08 Jan 2025 02:04:51 GMT + < Content-Length: 71 + < + Forbidden: Container can only be started via configured start endpoint + * Connection #0 to host localhost left intact + ``` + + Hitting `/start` when the container is down: + + ```curl + curl -sv -X GET -H "Host: hello-world.godoxy.local" -H "X-Goproxy-Check-Redirect: skip" http://localhost/start + * Host localhost:80 was resolved. + * IPv6: ::1 + * IPv4: 127.0.0.1 + * Trying [::1]:80... + * Connected to localhost (::1) port 80 + > GET /start HTTP/1.1 + > Host: hello-world.godoxy.local + > User-Agent: curl/8.7.1 + > Accept: */* + > X-Goproxy-Check-Redirect: skip + > + * Request completely sent off + < HTTP/1.1 200 OK + < Date: Wed, 08 Jan 2025 02:13:39 GMT + < Content-Length: 0 + < + * Connection #0 to host localhost left intact + ``` + +- Caddyfile like rules + + ```yaml + proxy.goaccess.rules: | + - name: default + do: | + rewrite / /index.html + serve /var/www/goaccess + - name: ws + on: | + header Connection Upgrade + header Upgrade websocket + do: bypass # do nothing, pass to reverse proxy + + proxy.app.rules: | + - name: default + do: bypass # do nothing, pass to reverse proxy + - name: block POST and PUT + on: method POST | method PUT + do: error 403 Forbidden + ``` + +```` + +- config reload will now cause all servers to fully restart (i.e. proxy, api, prometheus, etc) +- multiline-string as list now treated as YAML list, which requires hyphen prefix `-`, i.e. + ```yaml + proxy.app.middlewares.request.hide_headers: + - X-Header1 + - X-Header2 +```` +- autocert now supports hot-reload +- middleware compose now supports cross-referencing, e.g. + ```yaml + foo: + - use: RedirectHTTP + bar: # in the same file or different file + - use: foo@file + ``` + +- Fixes + - bug: cert renewal failure no longer causes renew schdueler to stuck forever + - bug: access log writes to closed file after config reload