diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 2719443..66cd116 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -23,10 +23,10 @@ lint: enabled: - hadolint@2.12.1-beta - actionlint@1.7.6 - - checkov@3.2.347 + - checkov@3.2.350 - git-diff-check - gofmt@1.20.4 - - golangci-lint@1.62.2 + - golangci-lint@1.63.4 - osv-scanner@1.9.2 - oxipng@9.1.3 - prettier@3.4.2 diff --git a/cmd/main.go b/cmd/main.go index ec79e13..152336c 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,23 +3,19 @@ package main import ( "encoding/json" "log" - "net/http" "os" "os/signal" "syscall" "time" "github.com/yusing/go-proxy/internal" - "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" - "github.com/yusing/go-proxy/internal/entrypoint" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/metrics" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/net/http/server" + "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/pkg" ) @@ -97,16 +93,16 @@ func main() { switch args.Command { case common.CommandListRoutes: cfg.StartProxyProviders() - printJSON(config.RoutesByAlias()) + printJSON(routes.RoutesByAlias()) return case common.CommandListConfigs: - printJSON(config.Value()) + printJSON(cfg.Value()) return case common.CommandDebugListEntries: - printJSON(config.DumpEntries()) + printJSON(cfg.DumpEntries()) return case common.CommandDebugListProviders: - printJSON(config.DumpProviders()) + printJSON(cfg.DumpProviders()) return } @@ -114,7 +110,7 @@ func main() { logging.Warn().Msg("API JWT secret is empty, authentication is disabled") } - cfg.StartProxyProviders() + cfg.Start() config.WatchChanges() sig := make(chan os.Signal, 1) @@ -122,44 +118,12 @@ func main() { signal.Notify(sig, syscall.SIGTERM) signal.Notify(sig, syscall.SIGHUP) - autocert := config.GetAutoCertProvider() - if autocert != nil { - if err := autocert.Setup(); err != nil { - E.LogFatal("autocert setup error", err) - } - } else { - logging.Info().Msg("autocert not configured") - } - - server.StartServer(server.Options{ - Name: "proxy", - CertProvider: autocert, - HTTPAddr: common.ProxyHTTPAddr, - HTTPSAddr: common.ProxyHTTPSAddr, - Handler: http.HandlerFunc(entrypoint.Handler), - }) - server.StartServer(server.Options{ - Name: "api", - CertProvider: autocert, - HTTPAddr: common.APIHTTPAddr, - Handler: api.NewHandler(), - }) - - if common.PrometheusEnabled { - server.StartServer(server.Options{ - Name: "metrics", - CertProvider: autocert, - HTTPAddr: common.MetricsHTTPAddr, - Handler: metrics.NewHandler(), - }) - } - // wait for signal <-sig // grafully shutdown logging.Info().Msg("shutting down") - _ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown)) + _ = task.GracefulShutdown(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) } func prepareDirectory(dir string) { diff --git a/internal/api/handler.go b/internal/api/handler.go index 0ea5aa3..2a05022 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -8,38 +8,41 @@ import ( "github.com/yusing/go-proxy/internal/api/v1/auth" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" + config "github.com/yusing/go-proxy/internal/config/types" ) type ServeMux struct{ *http.ServeMux } -func NewServeMux() ServeMux { - return ServeMux{http.NewServeMux()} -} - func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) } -func NewHandler() http.Handler { - mux := NewServeMux() +func NewHandler(cfg config.ConfigInstance) http.Handler { + mux := ServeMux{http.NewServeMux()} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) - mux.HandleFunc("POST", "/v1/reload", v1.Reload) - mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(v1.List)) - mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(v1.List)) - mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(v1.List)) + mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) + mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) + mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) + mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) - mux.HandleFunc("GET", "/v1/stats", v1.Stats) - mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) + mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) + mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) return mux } +func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(cfg, w, r) + } +} + // allow only requests to API server with localhost. func checkHost(f http.HandlerFunc) http.HandlerFunc { if common.IsDebug { @@ -55,4 +58,4 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { LogDebug(r).Interface("headers", r.Header).Msg("API request") f(w, r) } -} +} \ No newline at end of file diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index bc8720b..63f3aa3 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -9,7 +9,7 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/route/provider" diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 85d41a5..52617b4 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -6,9 +6,10 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/routes" + route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) @@ -24,7 +25,7 @@ const ( ListTasks = "tasks" ) -func List(w http.ResponseWriter, r *http.Request) { +func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { what := r.PathValue("what") if what == "" { what = ListRoutes @@ -40,7 +41,7 @@ func List(w http.ResponseWriter, r *http.Request) { U.RespondJSON(w, r, route) } case ListRoutes: - U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type")))) + U.RespondJSON(w, r, routes.RoutesByAlias(route.RouteType(r.FormValue("type")))) case ListFiles: listFiles(w, r) case ListMiddlewares: @@ -48,9 +49,9 @@ func List(w http.ResponseWriter, r *http.Request) { case ListMiddlewareTraces: U.RespondJSON(w, r, middleware.GetAllTrace()) case ListMatchDomains: - U.RespondJSON(w, r, config.Value().MatchDomains) + U.RespondJSON(w, r, cfg.Value().MatchDomains) case ListHomepageConfig: - U.RespondJSON(w, r, config.HomepageConfig()) + U.RespondJSON(w, r, routes.HomepageConfig(cfg.Value().Homepage.UseDefaultCategories)) case ListTasks: U.RespondJSON(w, r, task.DebugTaskList()) default: @@ -60,9 +61,9 @@ func List(w http.ResponseWriter, r *http.Request) { func listRoute(which string) any { if which == "" || which == "all" { - return config.RoutesByAlias() + return routes.RoutesByAlias() } - routes := config.RoutesByAlias() + routes := routes.RoutesByAlias() route, ok := routes[which] if !ok { return nil diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index 42a4198..defa4e4 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -4,11 +4,11 @@ import ( "net/http" U "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" ) -func Reload(w http.ResponseWriter, r *http.Request) { - if err := config.Reload(); err != nil { +func Reload(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { + if err := cfg.Reload(); err != nil { U.HandleErr(w, r, err) return } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 99ba304..0d9617b 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -9,25 +9,25 @@ import ( "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) -func Stats(w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, getStats()) +func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { + U.RespondJSON(w, r, getStats(cfg)) } -func StatsWS(w http.ResponseWriter, r *http.Request) { +func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { var originPats []string localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - if len(config.Value().MatchDomains) == 0 { + if len(cfg.Value().MatchDomains) == 0 { U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") originPats = []string{"*"} } else { - originPats = make([]string, len(config.Value().MatchDomains)) - for i, domain := range config.Value().MatchDomains { + originPats = make([]string, len(cfg.Value().MatchDomains)) + for i, domain := range cfg.Value().MatchDomains { originPats[i] = "*" + domain } originPats = append(originPats, localAddresses...) @@ -52,7 +52,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { defer ticker.Stop() for range ticker.C { - stats := getStats() + stats := getStats(cfg) if err := wsjson.Write(ctx, conn, stats); err != nil { U.LogError(r).Msg("failed to write JSON") return @@ -62,9 +62,9 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { var startTime = time.Now() -func getStats() map[string]any { +func getStats(cfg config.ConfigInstance) map[string]any { return map[string]any{ - "proxies": config.Statistics(), + "proxies": cfg.Statistics(), "uptime": strutils.FormatDuration(time.Since(startTime)), } } diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 399cb75..5f172da 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -18,8 +18,6 @@ func (p *Provider) Setup() (err E.Error) { } } - p.ScheduleRenewal() - for _, expiry := range p.GetExpiries() { logger.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) break diff --git a/internal/config/config.go b/internal/config/config.go index f1c42b7..da9b89e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,12 +7,15 @@ import ( "sync" "time" + "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/entrypoint" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/metrics" + "github.com/yusing/go-proxy/internal/net/http/server" "github.com/yusing/go-proxy/internal/notif" proxy "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/task" @@ -26,7 +29,9 @@ type Config struct { value *types.Config providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider - task *task.Task + entrypoint *entrypoint.Entrypoint + + task *task.Task } var ( @@ -45,15 +50,18 @@ Make sure you rename it back before next time you start.` You may run "ls-config" to show or dump the current config.` ) +var Validate = types.Validate + func GetInstance() *Config { return instance } func newConfig() *Config { return &Config{ - value: types.DefaultConfig(), - providers: F.NewMapOf[string, *proxy.Provider](), - task: task.RootTask("config", false), + value: types.DefaultConfig(), + providers: F.NewMapOf[string, *proxy.Provider](), + entrypoint: entrypoint.NewEntrypoint(), + task: task.RootTask("config", false), } } @@ -66,11 +74,6 @@ func Load() (*Config, E.Error) { return instance, instance.load() } -func Validate(data []byte) E.Error { - var model types.Config - return utils.DeserializeYAML(data, &model) -} - func MatchDomains() []string { return instance.value.MatchDomains } @@ -101,6 +104,7 @@ func OnConfigChange(ev []events.Event) { } if err := Reload(); err != nil { + logger.Warn().Msg("using last config") // recovered in event queue panic(err) } @@ -122,15 +126,19 @@ func Reload() E.Error { // -> replace config -> start new subtasks instance.task.Finish("config changed") instance = newCfg - instance.StartProxyProviders() + instance.Start() return nil } -func Value() types.Config { - return *instance.value +func (cfg *Config) Value() *types.Config { + return instance.value } -func GetAutoCertProvider() *autocert.Provider { +func (cfg *Config) Reload() E.Error { + return Reload() +} + +func (cfg *Config) AutoCertProvider() *autocert.Provider { return instance.autocertProvider } @@ -138,6 +146,26 @@ func (cfg *Config) Task() *task.Task { return cfg.task } +func (cfg *Config) Start() { + cfg.StartAutoCert() + cfg.StartProxyProviders() + cfg.StartServers() +} + +func (cfg *Config) StartAutoCert() { + autocert := cfg.autocertProvider + if autocert == nil { + logging.Info().Msg("autocert not configured") + return + } + + if err := autocert.Setup(); err != nil { + E.LogFatal("autocert setup error", err) + } else { + autocert.ScheduleRenewal(cfg.task) + } +} + func (cfg *Config) StartProxyProviders() { errs := cfg.providers.CollectErrorsParallel( func(_ string, p *proxy.Provider) error { @@ -149,6 +177,30 @@ func (cfg *Config) StartProxyProviders() { } } +func (cfg *Config) StartServers() { + server.StartServer(cfg.task, server.Options{ + Name: "proxy", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.ProxyHTTPAddr, + HTTPSAddr: common.ProxyHTTPSAddr, + Handler: cfg.entrypoint, + }) + server.StartServer(cfg.task, server.Options{ + Name: "api", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.APIHTTPAddr, + Handler: api.NewHandler(cfg), + }) + if common.PrometheusEnabled { + server.StartServer(cfg.task, server.Options{ + Name: "metrics", + CertProvider: cfg.AutoCertProvider(), + HTTPAddr: common.MetricsHTTPAddr, + Handler: metrics.NewHandler(), + }) + } +} + func (cfg *Config) load() E.Error { const errMsg = "config load error" @@ -164,8 +216,8 @@ func (cfg *Config) load() E.Error { // errors are non fatal below errs := E.NewBuilder(errMsg) - errs.Add(entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) - errs.Add(entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) + errs.Add(cfg.entrypoint.SetMiddlewares(model.Entrypoint.Middlewares)) + errs.Add(cfg.entrypoint.SetAccessLogger(cfg.task, model.Entrypoint.AccessLog)) errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initAutoCert(model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) @@ -176,7 +228,8 @@ func (cfg *Config) load() E.Error { model.MatchDomains[i] = "." + domain } } - entrypoint.SetFindRouteDomains(model.MatchDomains) + cfg.entrypoint.SetFindRouteDomains(model.MatchDomains) + return errs.Error() } diff --git a/internal/config/query.go b/internal/config/query.go index 4634b7a..a3533f3 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -1,20 +1,14 @@ package config import ( - "strings" - - "github.com/yusing/go-proxy/internal/homepage" route "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/route/entry" - proxy "github.com/yusing/go-proxy/internal/route/provider" - "github.com/yusing/go-proxy/internal/route/routes" + "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/route/types" - "github.com/yusing/go-proxy/internal/utils/strutils" ) -func DumpEntries() map[string]*types.RawEntry { +func (cfg *Config) DumpEntries() map[string]*types.RawEntry { entries := make(map[string]*types.RawEntry) - instance.providers.RangeAll(func(_ string, p *proxy.Provider) { + cfg.providers.RangeAll(func(_ string, p *provider.Provider) { p.RangeRoutes(func(alias string, r *route.Route) { entries[alias] = r.Entry }) @@ -22,107 +16,20 @@ func DumpEntries() map[string]*types.RawEntry { return entries } -func DumpProviders() map[string]*proxy.Provider { - entries := make(map[string]*proxy.Provider) - instance.providers.RangeAll(func(name string, p *proxy.Provider) { +func (cfg *Config) DumpProviders() map[string]*provider.Provider { + entries := make(map[string]*provider.Provider) + cfg.providers.RangeAll(func(name string, p *provider.Provider) { entries[name] = p }) return entries } -func HomepageConfig() homepage.Config { - hpCfg := homepage.NewHomePageConfig() - routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { - en := r.RawEntry() - item := en.Homepage - if item == nil { - item = new(homepage.Item) - item.Show = true - } - - if !item.IsEmpty() { - item.Show = true - } - - if !item.Show { - return - } - - item.Alias = alias - - if item.Name == "" { - item.Name = strutils.Title( - strings.ReplaceAll( - strings.ReplaceAll(alias, "-", " "), - "_", " ", - ), - ) - } - - if instance.value.Homepage.UseDefaultCategories { - if en.Container != nil && item.Category == "" { - if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok { - item.Category = category - } - } - - if item.Category == "" { - if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok { - item.Category = category - } - } - } - - switch { - case entry.IsDocker(r): - if item.Category == "" { - item.Category = "Docker" - } - item.SourceType = string(proxy.ProviderTypeDocker) - case entry.UseLoadBalance(r): - if item.Category == "" { - item.Category = "Load-balanced" - } - item.SourceType = "loadbalancer" - default: - if item.Category == "" { - item.Category = "Others" - } - item.SourceType = string(proxy.ProviderTypeFile) - } - - item.AltURL = r.TargetURL().String() - hpCfg.Add(item) - }) - return hpCfg -} - -func RoutesByAlias(typeFilter ...route.RouteType) map[string]any { - rts := make(map[string]any) - if len(typeFilter) == 0 || typeFilter[0] == "" { - typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream} - } - for _, t := range typeFilter { - switch t { - case route.RouteTypeReverseProxy: - routes.GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { - rts[alias] = r - }) - case route.RouteTypeStream: - routes.GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) { - rts[alias] = r - }) - } - } - return rts -} - -func Statistics() map[string]any { +func (cfg *Config) Statistics() map[string]any { nTotalStreams := 0 nTotalRPs := 0 - providerStats := make(map[string]proxy.ProviderStats) + providerStats := make(map[string]provider.ProviderStats) - instance.providers.RangeAll(func(name string, p *proxy.Provider) { + cfg.providers.RangeAll(func(name string, p *provider.Provider) { stats := p.Statistics() providerStats[name] = stats diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 2e27891..9590337 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -3,6 +3,8 @@ package types import ( "github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/utils" + + E "github.com/yusing/go-proxy/internal/error" ) type ( @@ -24,6 +26,12 @@ type ( AccessLog *accesslog.Config `json:"access_log" validate:"omitempty"` } NotificationConfig map[string]any + + ConfigInstance interface { + Value() *Config + Reload() E.Error + Statistics() map[string]any + } ) func DefaultConfig() *Config { @@ -35,6 +43,11 @@ func DefaultConfig() *Config { } } +func Validate(data []byte) E.Error { + var model Config + return utils.DeserializeYAML(data, &model) +} + func init() { utils.RegisterDefaultValueFactory(DefaultConfig) } diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 0046a93..a2f5c0a 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "strings" - "sync" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/accesslog" @@ -17,32 +16,31 @@ import ( "github.com/yusing/go-proxy/internal/utils/strutils" ) -var findRouteFunc = findRouteAnyDomain - -var ( - epMiddleware *middleware.Middleware - epMiddlewareMu sync.Mutex - - epAccessLogger *accesslog.AccessLogger - epAccessLoggerMu sync.Mutex -) +type Entrypoint struct { + middleware *middleware.Middleware + accessLogger *accesslog.AccessLogger + findRouteFunc func(host string) (route.HTTPRoute, error) +} var ErrNoSuchRoute = errors.New("no such route") -func SetFindRouteDomains(domains []string) { - if len(domains) == 0 { - findRouteFunc = findRouteAnyDomain - } else { - findRouteFunc = findRouteByDomains(domains) +func NewEntrypoint() *Entrypoint { + return &Entrypoint{ + findRouteFunc: findRouteAnyDomain, } } -func SetMiddlewares(mws []map[string]any) error { - epMiddlewareMu.Lock() - defer epMiddlewareMu.Unlock() +func (ep *Entrypoint) SetFindRouteDomains(domains []string) { + if len(domains) == 0 { + ep.findRouteFunc = findRouteAnyDomain + } else { + ep.findRouteFunc = findRouteByDomains(domains) + } +} +func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { if len(mws) == 0 { - epMiddleware = nil + ep.middleware = nil return nil } @@ -50,22 +48,19 @@ func SetMiddlewares(mws []map[string]any) error { if err != nil { return err } - epMiddleware = mid + ep.middleware = mid logger.Debug().Msg("entrypoint middleware loaded") return nil } -func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { - epAccessLoggerMu.Lock() - defer epAccessLoggerMu.Unlock() - +func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { if cfg == nil { - epAccessLogger = nil + ep.accessLogger = nil return } - epAccessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) + ep.accessLogger, err = accesslog.NewFileAccessLogger(parent, cfg) if err != nil { return } @@ -73,28 +68,18 @@ func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { return } -func Handler(w http.ResponseWriter, r *http.Request) { - mux, err := findRouteFunc(r.Host) +func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + mux, err := ep.findRouteFunc(r.Host) if err == nil { - if epAccessLogger != nil { - epMiddlewareMu.Lock() - if epAccessLogger != nil { - w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { - epAccessLogger.Log(r, resp) - return nil - }) - } - epMiddlewareMu.Unlock() + if ep.accessLogger != nil { + w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { + ep.accessLogger.Log(r, resp) + return nil + }) } - if epMiddleware != nil { - epMiddlewareMu.Lock() - if epMiddleware != nil { - mid := epMiddleware - epMiddlewareMu.Unlock() - mid.ServeHTTP(mux.ServeHTTP, w, r) - return - } - epMiddlewareMu.Unlock() + if ep.middleware != nil { + ep.middleware.ServeHTTP(mux.ServeHTTP, w, r) + return } mux.ServeHTTP(w, r) return diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go index 65153a4..438bc1f 100644 --- a/internal/entrypoint/entrypoint_test.go +++ b/internal/entrypoint/entrypoint_test.go @@ -8,18 +8,19 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -var r route.HTTPRoute +var ( + r route.HTTPRoute + ep = NewEntrypoint() +) func run(t *testing.T, match []string, noMatch []string) { t.Helper() t.Cleanup(routes.TestClear) - t.Cleanup(func() { - SetFindRouteDomains(nil) - }) + t.Cleanup(func() { ep.SetFindRouteDomains(nil) }) for _, test := range match { t.Run(test, func(t *testing.T) { - found, err := findRouteFunc(test) + found, err := ep.findRouteFunc(test) ExpectNoError(t, err) ExpectTrue(t, found == &r) }) @@ -27,7 +28,7 @@ func run(t *testing.T, match []string, noMatch []string) { for _, test := range noMatch { t.Run(test, func(t *testing.T) { - _, err := findRouteFunc(test) + _, err := ep.findRouteFunc(test) ExpectError(t, ErrNoSuchRoute, err) }) } @@ -72,7 +73,7 @@ func TestFindRouteExactHostMatch(t *testing.T) { } func TestFindRouteByDomains(t *testing.T) { - SetFindRouteDomains([]string{ + ep.SetFindRouteDomains([]string{ ".domain.com", ".sub.domain.com", }) @@ -97,7 +98,7 @@ func TestFindRouteByDomains(t *testing.T) { } func TestFindRouteByDomainsExactMatch(t *testing.T) { - SetFindRouteDomains([]string{ + ep.SetFindRouteDomains([]string{ ".domain.com", ".sub.domain.com", }) diff --git a/internal/error/subject.go b/internal/error/subject.go index b649f0e..f46727d 100644 --- a/internal/error/subject.go +++ b/internal/error/subject.go @@ -8,8 +8,8 @@ import ( //nolint:errname type withSubject struct { - Subject string `json:"subject"` - Err error `json:"err"` + Subjects []string `json:"subjects"` + Err error `json:"err"` } const subjectSep = " > " @@ -30,13 +30,18 @@ func PrependSubject(subject string, err error) error { case Error: return err.Subject(subject) } - return &withSubject{subject, err} + return &withSubject{[]string{subject}, err} } func (err *withSubject) Prepend(subject string) *withSubject { clone := *err if subject != "" { - clone.Subject = subject + subjectSep + clone.Subject + switch subject[0] { + case '[', '(', '{': + clone.Subjects[len(clone.Subjects)-1] += subject + default: + clone.Subjects = append(clone.Subjects, subject) + } } return &clone } @@ -50,7 +55,22 @@ func (err *withSubject) Unwrap() error { } func (err *withSubject) Error() string { - subjects := strings.Split(err.Subject, subjectSep) - subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1]) - return strings.Join(subjects, subjectSep) + ": " + err.Err.Error() + // subject is in reversed order + n := len(err.Subjects) + size := 0 + errStr := err.Err.Error() + var sb strings.Builder + for _, s := range err.Subjects { + size += len(s) + } + sb.Grow(size + 2 + n*len(subjectSep) + len(errStr)) + + for i := n - 1; i > 0; i-- { + sb.WriteString(err.Subjects[i]) + sb.WriteString(subjectSep) + } + sb.WriteString(highlight(err.Subjects[0])) + sb.WriteString(": ") + sb.WriteString(errStr) + return sb.String() } diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index d6726a6..9eb899e 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -4,7 +4,10 @@ import ( "net" "net/http" + "github.com/go-playground/validator/v10" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -16,7 +19,7 @@ type ( } CIDRWhitelistOpts struct { Allow []*types.CIDR `validate:"min=1"` - StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,gte=400,lte=599"` + StatusCode int `json:"status_code" aliases:"status" validate:"omitempty,status_code"` Message string } ) @@ -30,6 +33,13 @@ var ( } ) +func init() { + utils.Validator().RegisterValidation("status_code", func(fl validator.FieldLevel) bool { + statusCode := fl.Field().Int() + return gphttp.IsStatusCodeValid(int(statusCode)) + }) +} + // setup implements MiddlewareWithSetup. func (wl *cidrWhitelist) setup() { wl.CIDRWhitelistOpts = cidrWhitelistDefaults diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index b9bd3a1..64fc9e8 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -24,6 +24,18 @@ func TestCIDRWhitelistValidation(t *testing.T) { "message": testMessage, }) ExpectNoError(t, err) + _, err = CIDRWhiteList.New(OptionsRaw{ + "allow": []string{"192.168.2.100/32"}, + "message": testMessage, + "status": 403, + }) + ExpectNoError(t, err) + _, err = CIDRWhiteList.New(OptionsRaw{ + "allow": []string{"192.168.2.100/32"}, + "message": testMessage, + "status_code": 403, + }) + ExpectNoError(t, err) }) t.Run("missing allow", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/http/reverseproxy/reverse_proxy_mod.go index 91f7754..0bb0d4b 100644 --- a/internal/net/http/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/http/reverseproxy/reverse_proxy_mod.go @@ -168,24 +168,6 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", // the target request will be for /base/dir. -// -// NewReverseProxy does not rewrite the Host header. -// -// To customize the ReverseProxy behavior beyond what -// NewReverseProxy provides, use ReverseProxy directly -// with a Rewrite function. The ProxyRequest SetURL method -// may be used to route the outbound request. (Note that SetURL, -// unlike NewReverseProxy, rewrites the Host header -// of the outbound request by default.) -// -// proxy := &ReverseProxy{ -// Rewrite: func(r *ProxyRequest) { -// r.SetURL(target) -// r.Out.Host = r.In.Host // if desired -// }, -// } -// - func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy { if transport == nil { panic("nil transport") diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index fde17fd..ce30d01 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -35,9 +35,9 @@ type Options struct { Handler http.Handler } -func StartServer(opt Options) (s *Server) { +func StartServer(parent task.Parent, opt Options) (s *Server) { s = NewServer(opt) - s.Start() + s.Start(parent) return s } @@ -83,11 +83,13 @@ func NewServer(opt Options) (s *Server) { // If both are not set, this does nothing. // // Start() is non-blocking. -func (s *Server) Start() { +func (s *Server) Start(parent task.Parent) { if s.http == nil && s.https == nil { return } + task := parent.Subtask("server."+s.Name, false) + s.startTime = time.Now() if s.http != nil { go func() { @@ -105,7 +107,7 @@ func (s *Server) Start() { s.l.Info().Str("addr", s.https.Addr).Msgf("server started") } - task.OnProgramExit("server."+s.Name+".stop", s.stop) + task.OnCancel("stop", s.stop) } func (s *Server) stop() { @@ -113,14 +115,19 @@ func (s *Server) stop() { return } + ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) + defer cancel() + if s.http != nil && s.httpStarted { - s.handleErr("http", s.http.Shutdown(task.RootContext())) + s.handleErr("http", s.http.Shutdown(ctx)) s.httpStarted = false + s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped") } if s.https != nil && s.httpsStarted { - s.handleErr("https", s.https.Shutdown(task.RootContext())) + s.handleErr("https", s.https.Shutdown(ctx)) s.httpsStarted = false + s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped") } } diff --git a/internal/route/provider/docker_labels.yaml b/internal/route/provider/docker_labels.yaml index 8065654..8ab6c21 100644 --- a/internal/route/provider/docker_labels.yaml +++ b/internal/route/provider/docker_labels.yaml @@ -72,9 +72,9 @@ proxy.app1.host: 10.0.0.254 proxy.app1.port: 80 proxy.app1.path_patterns: | # Check https://pkg.go.dev/net/http#hdr-Patterns-ServeMux for syntax - GET / # accept any GET request - POST /auth # for /auth and /auth/* accept only POST - GET /home/{$} # for exactly /home + - GET / # accept any GET request + - POST /auth # for /auth and /auth/* accept only POST + - GET /home/{$} # for exactly /home proxy.app1.healthcheck.disabled: false proxy.app1.healthcheck.path: / proxy.app1.healthcheck.interval: 5s diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 46a95b6..45eb1eb 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -5,6 +5,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route/entry" + "github.com/yusing/go-proxy/internal/route/provider/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher" ) @@ -87,10 +88,10 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { switch handler.provider.GetType() { - case ProviderTypeDocker: + case types.ProviderTypeDocker: return route.Entry.Container.ContainerID == event.ActorID || route.Entry.Container.ContainerName == event.ActorName - case ProviderTypeFile: + case types.ProviderTypeFile: return true } // should never happen diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 9933bb3..32df015 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -10,6 +10,8 @@ import ( "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/route/provider/types" + route "github.com/yusing/go-proxy/internal/route/types" "github.com/yusing/go-proxy/internal/task" W "github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher/events" @@ -20,7 +22,7 @@ type ( ProviderImpl `json:"-"` name string - t ProviderType + t types.ProviderType routes R.Routes watcher W.Watcher @@ -31,24 +33,20 @@ type ( NewWatcher() W.Watcher Logger() *zerolog.Logger } - ProviderType string ProviderStats struct { - NumRPs int `json:"num_reverse_proxies"` - NumStreams int `json:"num_streams"` - Type ProviderType `json:"type"` + NumRPs int `json:"num_reverse_proxies"` + NumStreams int `json:"num_streams"` + Type types.ProviderType `json:"type"` } ) const ( - ProviderTypeDocker ProviderType = "docker" - ProviderTypeFile ProviderType = "file" - providerEventFlushInterval = 300 * time.Millisecond ) var ErrEmptyProviderName = errors.New("empty provider name") -func newProvider(name string, t ProviderType) *Provider { +func newProvider(name string, t types.ProviderType) *Provider { return &Provider{ name: name, t: t, @@ -61,7 +59,7 @@ func NewFileProvider(filename string) (p *Provider, err error) { if name == "" { return nil, ErrEmptyProviderName } - p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile) + p = newProvider(strings.ReplaceAll(name, ".", "_"), types.ProviderTypeFile) p.ProviderImpl, err = FileProviderImpl(filename) if err != nil { return nil, err @@ -75,7 +73,7 @@ func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) return nil, ErrEmptyProviderName } - p = newProvider(name, ProviderTypeDocker) + p = newProvider(name, types.ProviderTypeDocker) p.ProviderImpl, err = DockerProviderImpl(name, dockerHost, p.IsExplicitOnly()) if err != nil { return nil, err @@ -92,7 +90,7 @@ func (p *Provider) GetName() string { return p.name } -func (p *Provider) GetType() ProviderType { +func (p *Provider) GetType() types.ProviderType { return p.t } @@ -171,9 +169,9 @@ func (p *Provider) Statistics() ProviderStats { numStreams := 0 p.routes.RangeAll(func(_ string, r *R.Route) { switch r.Type { - case R.RouteTypeReverseProxy: + case route.RouteTypeReverseProxy: numRPs++ - case R.RouteTypeStream: + case route.RouteTypeStream: numStreams++ } }) diff --git a/internal/route/provider/types/provider_type.go b/internal/route/provider/types/provider_type.go new file mode 100644 index 0000000..2907762 --- /dev/null +++ b/internal/route/provider/types/provider_type.go @@ -0,0 +1,8 @@ +package types + +type ProviderType string + +const ( + ProviderTypeDocker ProviderType = "docker" + ProviderTypeFile ProviderType = "file" +) diff --git a/internal/route/route.go b/internal/route/route.go index 4101d9a..003b912 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -14,11 +14,10 @@ import ( ) type ( - RouteType string - Route struct { + Route struct { _ U.NoCopy impl - Type RouteType + Type types.RouteType Entry *RawEntry } Routes = F.Map[string, *Route] @@ -34,11 +33,6 @@ type ( RawEntries = types.RawEntries ) -const ( - RouteTypeStream RouteType = "stream" - RouteTypeReverseProxy RouteType = "reverse_proxy" -) - // function alias. var ( NewRoutes = F.NewMap[Routes] @@ -59,15 +53,15 @@ func NewRoute(raw *RawEntry) (*Route, E.Error) { return nil, err } - var t RouteType + var t types.RouteType var rt impl switch e := en.(type) { case *entry.StreamEntry: - t = RouteTypeStream + t = types.RouteTypeStream rt, err = NewStreamRoute(e) case *entry.ReverseProxyEntry: - t = RouteTypeReverseProxy + t = types.RouteTypeReverseProxy rt, err = NewHTTPRoute(e) default: panic("bug: should not reach here") diff --git a/internal/route/routes/query.go b/internal/route/routes/query.go new file mode 100644 index 0000000..cb87ae3 --- /dev/null +++ b/internal/route/routes/query.go @@ -0,0 +1,99 @@ +package routes + +import ( + "strings" + + "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/route/entry" + provider "github.com/yusing/go-proxy/internal/route/provider/types" + "github.com/yusing/go-proxy/internal/route/types" + route "github.com/yusing/go-proxy/internal/route/types" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func HomepageConfig(useDefaultCategories bool) homepage.Config { + hpCfg := homepage.NewHomePageConfig() + GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { + en := r.RawEntry() + item := en.Homepage + if item == nil { + item = new(homepage.Item) + item.Show = true + } + + if !item.IsEmpty() { + item.Show = true + } + + if !item.Show { + return + } + + item.Alias = alias + + if item.Name == "" { + item.Name = strutils.Title( + strings.ReplaceAll( + strings.ReplaceAll(alias, "-", " "), + "_", " ", + ), + ) + } + + if useDefaultCategories { + if en.Container != nil && item.Category == "" { + if category, ok := homepage.PredefinedCategories[en.Container.ImageName]; ok { + item.Category = category + } + } + + if item.Category == "" { + if category, ok := homepage.PredefinedCategories[strings.ToLower(alias)]; ok { + item.Category = category + } + } + } + + switch { + case entry.IsDocker(r): + if item.Category == "" { + item.Category = "Docker" + } + item.SourceType = string(provider.ProviderTypeDocker) + case entry.UseLoadBalance(r): + if item.Category == "" { + item.Category = "Load-balanced" + } + item.SourceType = "loadbalancer" + default: + if item.Category == "" { + item.Category = "Others" + } + item.SourceType = string(provider.ProviderTypeFile) + } + + item.AltURL = r.TargetURL().String() + hpCfg.Add(item) + }) + return hpCfg +} + +func RoutesByAlias(typeFilter ...route.RouteType) map[string]any { + rts := make(map[string]any) + if len(typeFilter) == 0 || typeFilter[0] == "" { + typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream} + } + for _, t := range typeFilter { + switch t { + case route.RouteTypeReverseProxy: + GetHTTPRoutes().RangeAll(func(alias string, r types.HTTPRoute) { + rts[alias] = r + }) + case route.RouteTypeStream: + GetStreamRoutes().RangeAll(func(alias string, r types.StreamRoute) { + rts[alias] = r + }) + } + } + return rts +} diff --git a/internal/route/types/route_type.go b/internal/route/types/route_type.go new file mode 100644 index 0000000..f5357db --- /dev/null +++ b/internal/route/types/route_type.go @@ -0,0 +1,8 @@ +package types + +type RouteType string + +const ( + RouteTypeStream RouteType = "stream" + RouteTypeReverseProxy RouteType = "reverse_proxy" +) diff --git a/next-release.md b/next-release.md index 222fc3b..a40874c 100644 --- a/next-release.md +++ b/next-release.md @@ -8,21 +8,21 @@ GoDoxy v0.8.2 expected changes Sample service showing this: ```yaml - hello-world: - image: nginxdemos/hello - container_name: hello-world - restart: "no" - ports: - - "9100:80" - labels: - proxy.aliases: hello-world - proxy.#1.port: 9100 - proxy.idle_timeout: 45s - proxy.wake_timeout: 30s - proxy.stop_method: stop - proxy.stop_timeout: 10s - proxy.stop_signal: SIGTERM - proxy.start_endpoint: "/start" + hello-world: + image: nginxdemos/hello + container_name: hello-world + restart: "no" + ports: + - "9100:80" + labels: + proxy.aliases: hello-world + proxy.#1.port: 9100 + proxy.idle_timeout: 45s + proxy.wake_timeout: 30s + proxy.stop_method: stop + proxy.stop_timeout: 10s + proxy.stop_signal: SIGTERM + proxy.start_endpoint: "/start" ``` Hitting `/` on this service when the container is down: @@ -38,14 +38,14 @@ GoDoxy v0.8.2 expected changes > Host: hello-world.godoxy.local > User-Agent: curl/8.7.1 > Accept: */* - > + > * Request completely sent off < HTTP/1.1 403 Forbidden < Content-Type: text/plain; charset=utf-8 < X-Content-Type-Options: nosniff < Date: Wed, 08 Jan 2025 02:04:51 GMT < Content-Length: 71 - < + < Forbidden: Container can only be started via configured start endpoint * Connection #0 to host localhost left intact ``` @@ -64,16 +64,17 @@ GoDoxy v0.8.2 expected changes > User-Agent: curl/8.7.1 > Accept: */* > X-Goproxy-Check-Redirect: skip - > + > * Request completely sent off < HTTP/1.1 200 OK < Date: Wed, 08 Jan 2025 02:13:39 GMT < Content-Length: 0 - < + < * Connection #0 to host localhost left intact ``` - Caddyfile like rules + ```yaml proxy.goaccess.rules: | - name: default @@ -92,4 +93,21 @@ GoDoxy v0.8.2 expected changes - name: block POST and PUT on: method POST | method PUT do: error 403 Forbidden -``` \ No newline at end of file + ``` + +```` + +- config reload will now cause all servers to fully restart (i.e. proxy, api, prometheus, etc) + +- multiline-string as list now treated as YAML list, which requires hyphen prefix `-`, i.e. + ```yaml + proxy.app.middlewares.request.hide_headers: + - X-Header1 + - X-Header2 +```` + +- autocert now supports hot-reload + +- Fixes + - bug: cert renewal failure no longer causes renew schdueler to stuck forever + - bug: access log writes to closed file after config reload