diff --git a/Makefile b/Makefile index c1f85a5..f37c7eb 100755 --- a/Makefile +++ b/Makefile @@ -30,6 +30,12 @@ get: debug: make build && sudo GOPROXY_DEBUG=1 bin/go-proxy +debug-trace: + make build && sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy + +profile: + GODEBUG=gctrace=1 make build && sudo GOPROXY_DEBUG=1 bin/go-proxy + mtrace: bin/go-proxy debug-ls-mtrace > mtrace.json diff --git a/cmd/main.go b/cmd/main.go index 0dc0ecf..79b9a92 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -20,6 +20,7 @@ import ( "github.com/yusing/go-proxy/internal/net/http/middleware" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/server" + "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/pkg" ) @@ -32,8 +33,14 @@ func main() { } l := logrus.WithField("module", "main") + timeFmt := "01-02 15:04:05" + fullTS := true - if common.IsDebug { + if common.IsTrace { + logrus.SetLevel(logrus.TraceLevel) + timeFmt = "04:05" + fullTS = false + } else if common.IsDebug { logrus.SetLevel(logrus.DebugLevel) } @@ -42,9 +49,9 @@ func main() { } else { logrus.SetFormatter(&logrus.TextFormatter{ DisableSorting: true, - FullTimestamp: true, + FullTimestamp: fullTS, ForceColors: true, - TimestampFormat: "01-02 15:04:05", + TimestampFormat: timeFmt, }) logrus.Infof("go-proxy version %s", pkg.GetVersion()) } @@ -76,21 +83,22 @@ func main() { middleware.LoadComposeFiles() - if err := config.Load(); err != nil { + var cfg *config.Config + var err E.NestedError + if cfg, err = config.Load(); err != nil { logrus.Warn(err) } - cfg := config.GetInstance() switch args.Command { case common.CommandListConfigs: - printJSON(cfg.Value()) + printJSON(config.Value()) return case common.CommandListRoutes: routes, err := query.ListRoutes() if err != nil { log.Printf("failed to connect to api server: %s", err) log.Printf("falling back to config file") - printJSON(cfg.RoutesByAlias()) + printJSON(config.RoutesByAlias()) } else { printJSON(routes) } @@ -103,10 +111,10 @@ func main() { printJSON(icons) return case common.CommandDebugListEntries: - printJSON(cfg.DumpEntries()) + printJSON(config.DumpEntries()) return case common.CommandDebugListProviders: - printJSON(cfg.DumpProviders()) + printJSON(config.DumpProviders()) return case common.CommandDebugListMTrace: trace, err := query.ListMiddlewareTraces() @@ -114,17 +122,25 @@ func main() { log.Fatal(err) } printJSON(trace) + return + case common.CommandDebugListTasks: + tasks, err := query.DebugListTasks() + if err != nil { + log.Fatal(err) + } + printJSON(tasks) + return } cfg.StartProxyProviders() - cfg.WatchChanges() + config.WatchChanges() sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM) signal.Notify(sig, syscall.SIGHUP) - autocert := cfg.GetAutoCertProvider() + autocert := config.GetAutoCertProvider() if autocert != nil { if err := autocert.Setup(); err != nil { l.Fatal(err) @@ -139,14 +155,14 @@ func main() { HTTPAddr: common.ProxyHTTPAddr, HTTPSAddr: common.ProxyHTTPSAddr, Handler: http.HandlerFunc(R.ProxyHandler), - RedirectToHTTPS: cfg.Value().RedirectToHTTPS, + RedirectToHTTPS: config.Value().RedirectToHTTPS, }) apiServer := server.InitAPIServer(server.Options{ Name: "api", CertProvider: autocert, HTTPAddr: common.APIHTTPAddr, - Handler: api.NewHandler(cfg), - RedirectToHTTPS: cfg.Value().RedirectToHTTPS, + Handler: api.NewHandler(), + RedirectToHTTPS: config.Value().RedirectToHTTPS, }) proxyServer.Start() @@ -157,8 +173,8 @@ func main() { // grafully shutdown logrus.Info("shutting down") - common.CancelGlobalContext() - common.GlobalContextWait(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) + task.CancelGlobalContext() + task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown)) } func prepareDirectory(dir string) { diff --git a/internal/api/handler.go b/internal/api/handler.go index 414657e..34429af 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -2,6 +2,7 @@ package api import ( "fmt" + "net" "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" @@ -21,34 +22,35 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler)) } -func NewHandler(cfg *config.Config) http.Handler { +func NewHandler() http.Handler { mux := NewServeMux() mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth)) - mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth)) - mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload)) - mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List)) - mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List)) + mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth) + mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth) + mux.HandleFunc("POST", "/v1/reload", v1.Reload) + mux.HandleFunc("GET", "/v1/list", v1.List) + mux.HandleFunc("GET", "/v1/list/{what}", v1.List) mux.HandleFunc("GET", "/v1/file", v1.GetFileContent) mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent) mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) - mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats)) - mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS)) + mux.HandleFunc("GET", "/v1/stats", v1.Stats) + mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc()) return mux } -// allow only requests to API server with host matching common.APIHTTPAddr. +// allow only requests to API server with localhost. func checkHost(f http.HandlerFunc) http.HandlerFunc { if common.IsDebug { return f } return func(w http.ResponseWriter, r *http.Request) { - if r.Host != common.APIHTTPAddr { - Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr) - http.Error(w, "invalid request", http.StatusForbidden) + host, _, _ := net.SplitHostPort(r.RemoteAddr) + if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { + Logger.Warnf("blocked API request from %s", host) + http.Error(w, "forbidden", http.StatusForbidden) return } f(w, r) diff --git a/internal/api/v1/checkhealth.go b/internal/api/v1/checkhealth.go index be0de54..9360b80 100644 --- a/internal/api/v1/checkhealth.go +++ b/internal/api/v1/checkhealth.go @@ -4,11 +4,10 @@ import ( "net/http" . "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/watcher/health" ) -func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { +func CheckHealth(w http.ResponseWriter, r *http.Request) { target := r.FormValue("target") if target == "" { HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest) diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index 9944a2d..c92d3ef 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -11,7 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/proxy/provider" + "github.com/yusing/go-proxy/internal/route/provider" ) func GetFileContent(w http.ResponseWriter, r *http.Request) { diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 86208a8..1b71452 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -9,19 +9,21 @@ import ( "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" ) const ( - ListRoutes = "routes" - ListConfigFiles = "config_files" - ListMiddlewares = "middlewares" - ListMiddlewareTrace = "middleware_trace" - ListMatchDomains = "match_domains" - ListHomepageConfig = "homepage_config" + ListRoutes = "routes" + ListConfigFiles = "config_files" + ListMiddlewares = "middlewares" + ListMiddlewareTraces = "middleware_trace" + ListMatchDomains = "match_domains" + ListHomepageConfig = "homepage_config" + ListTasks = "tasks" ) -func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { +func List(w http.ResponseWriter, r *http.Request) { what := r.PathValue("what") if what == "" { what = ListRoutes @@ -29,27 +31,24 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { switch what { case ListRoutes: - listRoutes(cfg, w, r) + U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type")))) case ListConfigFiles: listConfigFiles(w, r) case ListMiddlewares: - listMiddlewares(w, r) - case ListMiddlewareTrace: - listMiddlewareTrace(w, r) + U.RespondJSON(w, r, middleware.All()) + case ListMiddlewareTraces: + U.RespondJSON(w, r, middleware.GetAllTrace()) case ListMatchDomains: - listMatchDomains(cfg, w, r) + U.RespondJSON(w, r, config.Value().MatchDomains) case ListHomepageConfig: - listHomepageConfig(cfg, w, r) + U.RespondJSON(w, r, config.HomepageConfig()) + case ListTasks: + U.RespondJSON(w, r, task.DebugTaskMap()) default: U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) } } -func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type"))) - U.RespondJSON(w, r, routes) -} - func listConfigFiles(w http.ResponseWriter, r *http.Request) { files, err := utils.ListFiles(common.ConfigBasePath, 1) if err != nil { @@ -61,19 +60,3 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) { } U.RespondJSON(w, r, files) } - -func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, middleware.GetAllTrace()) -} - -func listMiddlewares(w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, middleware.All()) -} - -func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, cfg.Value().MatchDomains) -} - -func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, cfg.HomepageConfig()) -} diff --git a/internal/api/v1/query/query.go b/internal/api/v1/query/query.go index b588f83..88351f8 100644 --- a/internal/api/v1/query/query.go +++ b/internal/api/v1/query/query.go @@ -34,36 +34,34 @@ func ReloadServer() E.NestedError { return nil } -func ListRoutes() (map[string]map[string]any, E.NestedError) { - resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes)) +func List[T any](what string) (_ T, outErr E.NestedError) { + resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what)) if err != nil { - return nil, E.From(err) + outErr = E.From(err) + return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode) + outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode) + return } - var routes map[string]map[string]any - err = json.NewDecoder(resp.Body).Decode(&routes) + var res T + err = json.NewDecoder(resp.Body).Decode(&res) if err != nil { - return nil, E.From(err) + outErr = E.From(err) + return } - return routes, nil + return res, nil +} + +func ListRoutes() (map[string]map[string]any, E.NestedError) { + return List[map[string]map[string]any](v1.ListRoutes) } func ListMiddlewareTraces() (middleware.Traces, E.NestedError) { - resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace)) - if err != nil { - return nil, E.From(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode) - } - var traces middleware.Traces - err = json.NewDecoder(resp.Body).Decode(&traces) - if err != nil { - return nil, E.From(err) - } - return traces, nil + return List[middleware.Traces](v1.ListMiddlewareTraces) +} + +func DebugListTasks() (map[string]any, E.NestedError) { + return List[map[string]any](v1.ListTasks) } diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index da2c3a5..ffd0609 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -7,8 +7,8 @@ import ( "github.com/yusing/go-proxy/internal/config" ) -func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - if err := cfg.Reload(); err != nil { +func Reload(w http.ResponseWriter, r *http.Request) { + if err := config.Reload(); err != nil { U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) } else { w.WriteHeader(http.StatusOK) diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 46ea7f7..c86d325 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -14,19 +14,19 @@ import ( "github.com/yusing/go-proxy/internal/utils" ) -func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, getStats(cfg)) +func Stats(w http.ResponseWriter, r *http.Request) { + U.RespondJSON(w, r, getStats()) } -func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { +func StatsWS(w http.ResponseWriter, r *http.Request) { localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses)) + originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses)) if len(originPats) == 0 { U.Logger.Warnf("no match domains configured, accepting websocket request from all origins") originPats = []string{"*"} } else { - for i, domain := range cfg.Value().MatchDomains { + for i, domain := range config.Value().MatchDomains { originPats[i] = "*." + domain } originPats = append(originPats, localAddresses...) @@ -51,7 +51,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { defer ticker.Stop() for range ticker.C { - stats := getStats(cfg) + stats := getStats() if err := wsjson.Write(ctx, conn, stats); err != nil { U.Logger.Errorf("/stats/ws failed to write JSON: %s", err) return @@ -59,9 +59,9 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { } } -func getStats(cfg *config.Config) map[string]any { +func getStats() map[string]any { return map[string]any{ - "proxies": cfg.Statistics(), + "proxies": config.Statistics(), "uptime": utils.FormatDuration(server.GetProxyServer().Uptime()), } } diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 23549ea..2f0740d 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -9,7 +9,7 @@ import ( "github.com/go-acme/lego/v4/lego" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/config/types" ) type Config types.AutoCertConfig diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 011c89f..243bbb7 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -13,9 +13,9 @@ import ( "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" - "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" ) @@ -140,23 +140,22 @@ func (p *Provider) ScheduleRenewal() { if p.GetName() == ProviderLocal { return } - - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - task := common.NewTask("cert renew scheduler") - defer task.Finished() - - for { - select { - case <-task.Context().Done(): - return - case <-ticker.C: // check every 5 seconds - if err := p.renewIfNeeded(); err.HasError() { - logger.Warn(err) + go func() { + task := task.GlobalTask("cert renew scheduler") + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + defer task.Finish("cert renew scheduler stopped") + for { + select { + case <-task.Context().Done(): + return + case <-ticker.C: // check every 5 seconds + if err := p.renewIfNeeded(); err.HasError() { + logger.Warn(err) + } } } - } + }() } func (p *Provider) initClient() E.NestedError { diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index ef754c1..95d6089 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -17,7 +17,7 @@ func (p *Provider) Setup() (err E.NestedError) { } } - go p.ScheduleRenewal() + p.ScheduleRenewal() for _, expiry := range p.GetExpiries() { logger.Infof("certificate expire on %s", expiry) diff --git a/internal/common/args.go b/internal/common/args.go index 947c105..b0dd597 100644 --- a/internal/common/args.go +++ b/internal/common/args.go @@ -22,6 +22,7 @@ const ( CommandDebugListEntries = "debug-ls-entries" CommandDebugListProviders = "debug-ls-providers" CommandDebugListMTrace = "debug-ls-mtrace" + CommandDebugListTasks = "debug-ls-tasks" ) var ValidCommands = []string{ @@ -35,6 +36,7 @@ var ValidCommands = []string{ CommandDebugListEntries, CommandDebugListProviders, CommandDebugListMTrace, + CommandDebugListTasks, } func GetArgs() Args { diff --git a/internal/common/constants.go b/internal/common/constants.go index 4e863cf..44c9102 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -43,7 +43,6 @@ const ( HealthCheckIntervalDefault = 5 * time.Second HealthCheckTimeoutDefault = 5 * time.Second - IdleTimeoutDefault = "0" WakeTimeoutDefault = "30s" StopTimeoutDefault = "10s" StopMethodDefault = "stop" diff --git a/internal/common/env.go b/internal/common/env.go index 038da1d..13d71fe 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -15,6 +15,7 @@ var ( NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true) IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test") IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest) + IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug ProxyHTTPAddr, ProxyHTTPHost, diff --git a/internal/common/task.go b/internal/common/task.go deleted file mode 100644 index c471aa1..0000000 --- a/internal/common/task.go +++ /dev/null @@ -1,224 +0,0 @@ -package common - -import ( - "context" - "fmt" - "strings" - "sync" - "time" - - "github.com/puzpuzpuz/xsync/v3" - "github.com/sirupsen/logrus" -) - -var ( - globalCtx, globalCtxCancel = context.WithCancel(context.Background()) - taskWg sync.WaitGroup - tasksMap = xsync.NewMapOf[*task, struct{}]() -) - -type ( - Task interface { - Name() string - Context() context.Context - Subtask(usageFmt string, args ...interface{}) Task - SubtaskWithCancel(usageFmt string, args ...interface{}) (Task, context.CancelFunc) - Finished() - } - task struct { - ctx context.Context - subtasks []*task - name string - finished bool - mu sync.Mutex - } -) - -func (t *task) Name() string { - return t.name -} - -// Context returns the context associated with the task. This context is -// canceled when the task is finished. -func (t *task) Context() context.Context { - return t.ctx -} - -// Finished marks the task as finished and notifies the global wait group. -// Finished is thread-safe and idempotent. -func (t *task) Finished() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.finished { - return - } - t.finished = true - if _, ok := tasksMap.Load(t); ok { - taskWg.Done() - tasksMap.Delete(t) - } - logrus.Debugf("task %q finished", t.Name()) -} - -// Subtask returns a new subtask with the given name, derived from the receiver's context. -// -// The returned subtask is associated with the receiver's context and will be -// automatically registered and deregistered from the global task wait group. -// -// If the receiver's context is already canceled, the returned subtask will be -// canceled immediately. -// -// The returned subtask is safe for concurrent use. -func (t *task) Subtask(format string, args ...interface{}) Task { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - t.mu.Lock() - defer t.mu.Unlock() - sub := newSubTask(t.ctx, format) - t.subtasks = append(t.subtasks, sub) - return sub -} - -// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context, -// and a cancel function. The returned subtask is associated with the receiver's context and will be -// automatically registered and deregistered from the global task wait group. -// -// If the receiver's context is already canceled, the returned subtask will be canceled immediately. -// -// The returned cancel function is safe for concurrent use, and can be used to cancel the returned -// subtask at any time. -func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - t.mu.Lock() - defer t.mu.Unlock() - ctx, cancel := context.WithCancel(t.ctx) - sub := newSubTask(ctx, format) - t.subtasks = append(t.subtasks, sub) - return sub, cancel -} - -func (t *task) tree(prefix ...string) string { - var sb strings.Builder - var pre string - if len(prefix) > 0 { - pre = prefix[0] - } - sb.WriteString(pre) - sb.WriteString(t.Name() + "\n") - for _, sub := range t.subtasks { - if sub.finished { - continue - } - sb.WriteString(sub.tree(pre + " ")) - } - return sb.String() -} - -func newSubTask(ctx context.Context, name string) *task { - t := &task{ - ctx: ctx, - name: name, - } - tasksMap.Store(t, struct{}{}) - taskWg.Add(1) - logrus.Debugf("task %q started", name) - return t -} - -// NewTask returns a new Task with the given name, derived from the global -// context. -// -// The returned Task is associated with the global context and will be -// automatically registered and deregistered from the global context's wait -// group. -// -// If the global context is already canceled, the returned Task will be -// canceled immediately. -// -// The returned Task is not safe for concurrent use. -func NewTask(format string, args ...interface{}) Task { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - return newSubTask(globalCtx, format) -} - -// NewTaskWithCancel returns a new Task with the given name, derived from the -// global context, and a cancel function. The returned Task is associated with -// the global context and will be automatically registered and deregistered -// from the global task wait group. -// -// If the global context is already canceled, the returned Task will be -// canceled immediately. -// -// The returned Task is safe for concurrent use. -// -// The returned cancel function is safe for concurrent use, and can be used -// to cancel the returned Task at any time. -func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { - subCtx, cancel := context.WithCancel(globalCtx) - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - return newSubTask(subCtx, format), cancel -} - -// GlobalTask returns a new Task with the given name, associated with the -// global context. -// -// Unlike NewTask, GlobalTask does not automatically register or deregister -// the Task with the global task wait group. The returned Task is not -// started, but the name is formatted immediately. -// -// This is best used for main task that do not need to wait and -// will create a bunch of subtasks. -// -// The returned Task is safe for concurrent use. -func GlobalTask(format string, args ...interface{}) Task { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - return &task{ - ctx: globalCtx, - name: format, - } -} - -// CancelGlobalContext cancels the global context, which will cause all tasks -// created by GlobalTask or NewTask to be canceled. This should be called -// before exiting the program to ensure that all tasks are properly cleaned -// up. -func CancelGlobalContext() { - globalCtxCancel() -} - -// GlobalContextWait waits for all tasks to finish, up to the given timeout. -// -// If the timeout is exceeded, it prints a list of all tasks that were -// still running when the timeout was reached, and their current tree -// of subtasks. -func GlobalContextWait(timeout time.Duration) { - done := make(chan struct{}) - after := time.After(timeout) - go func() { - taskWg.Wait() - close(done) - }() - for { - select { - case <-done: - return - case <-after: - logrus.Warnln("Timeout waiting for these tasks to finish:") - tasksMap.Range(func(t *task, _ struct{}) bool { - logrus.Warnln(t.tree()) - return true - }) - return - } - } -} diff --git a/internal/config/config.go b/internal/config/config.go index bc6fcfc..aa318f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,51 +2,66 @@ package config import ( "os" + "sync" + "time" "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" - PR "github.com/yusing/go-proxy/internal/proxy/provider" - R "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/route" + proxy "github.com/yusing/go-proxy/internal/route/provider" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" - W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher/events" "gopkg.in/yaml.v3" ) type Config struct { value *types.Config - proxyProviders F.Map[string, *PR.Provider] + providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider - - l logrus.FieldLogger - - watcher W.Watcher - - reloadReq chan struct{} + task task.Task } -var instance *Config +var ( + instance *Config + cfgWatcher watcher.Watcher + logger = logrus.WithField("module", "config") + reloadMu sync.Mutex +) + +const configEventFlushInterval = 500 * time.Millisecond + +const ( + cfgRenameWarn = `Config file renamed, not reloading. +Make sure you rename it back before next time you start.` + cfgDeleteWarn = `Config file deleted, not reloading. +You may run "ls-config" to show or dump the current config.` +) func GetInstance() *Config { return instance } -func Load() E.NestedError { +func newConfig() *Config { + return &Config{ + value: types.DefaultConfig(), + providers: F.NewMapOf[string, *proxy.Provider](), + task: task.GlobalTask("config"), + } +} + +func Load() (*Config, E.NestedError) { if instance != nil { - return nil + return instance, nil } - instance = &Config{ - value: types.DefaultConfig(), - proxyProviders: F.NewMapOf[string, *PR.Provider](), - l: logrus.WithField("module", "config"), - watcher: W.NewConfigFileWatcher(common.ConfigFileName), - reloadReq: make(chan struct{}, 1), - } - return instance.load() + instance = newConfig() + cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName) + return instance, instance.load() } func Validate(data []byte) E.NestedError { @@ -54,87 +69,90 @@ func Validate(data []byte) E.NestedError { } func MatchDomains() []string { - if instance == nil { - logrus.Panic("config has not been loaded, please check if there is any errors") - } return instance.value.MatchDomains } -func (cfg *Config) Value() types.Config { - if cfg == nil { - logrus.Panic("config has not been loaded, please check if there is any errors") - } - return *cfg.value +func WatchChanges() { + task := task.GlobalTask("Config watcher") + eventQueue := events.NewEventQueue( + task, + configEventFlushInterval, + OnConfigChange, + func(err E.NestedError) { + logger.Error(err) + }, + ) + eventQueue.Start(cfgWatcher.Events(task.Context())) } -func (cfg *Config) GetAutoCertProvider() *autocert.Provider { - if instance == nil { - logrus.Panic("config has not been loaded, please check if there is any errors") +func OnConfigChange(flushTask task.Task, ev []events.Event) { + defer flushTask.Finish("config reload complete") + + // no matter how many events during the interval + // just reload once and check the last event + switch ev[len(ev)-1].Action { + case events.ActionFileRenamed: + logger.Warn(cfgRenameWarn) + return + case events.ActionFileDeleted: + logger.Warn(cfgDeleteWarn) + return + } + + if err := Reload(); err != nil { + logger.Error(err) } - return cfg.autocertProvider } -func (cfg *Config) Reload() (err E.NestedError) { - cfg.stopProviders() - err = cfg.load() - cfg.StartProxyProviders() - return +func Reload() E.NestedError { + // avoid race between config change and API reload request + reloadMu.Lock() + defer reloadMu.Unlock() + + newCfg := newConfig() + err := newCfg.load() + if err != nil { + return err + } + + // cancel all current subtasks -> wait + // -> replace config -> start new subtasks + instance.task.Finish("config changed") + instance.task.Wait() + *instance = *newCfg + instance.StartProxyProviders() + return nil +} + +func Value() types.Config { + return *instance.value +} + +func GetAutoCertProvider() *autocert.Provider { + return instance.autocertProvider +} + +func (cfg *Config) Task() task.Task { + return cfg.task } func (cfg *Config) StartProxyProviders() { - cfg.controlProviders("start", (*PR.Provider).StartAllRoutes) -} - -func (cfg *Config) WatchChanges() { - task := common.NewTask("Config watcher") - go func() { - defer task.Finished() - for { - select { - case <-task.Context().Done(): - return - case <-cfg.reloadReq: - if err := cfg.Reload(); err != nil { - cfg.l.Error(err) - } - } - } - }() - go func() { - eventCh, errCh := cfg.watcher.Events(task.Context()) - for { - select { - case <-task.Context().Done(): - return - case event := <-eventCh: - if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed { - cfg.l.Error("config file deleted or renamed, ignoring...") - continue - } else { - cfg.reloadReq <- struct{}{} - } - case err := <-errCh: - cfg.l.Error(err) - continue - } - } - }() -} - -func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) { - cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) { - p.RangeRoutes(func(a string, r *R.Route) { - do(a, r, p) - }) + b := E.NewBuilder("errors starting providers") + cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { + b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName()))) }) + + if b.HasError() { + logger.Error(b.Build()) + } } func (cfg *Config) load() (res E.NestedError) { b := E.NewBuilder("errors loading config") defer b.To(&res) - cfg.l.Debug("loading config") - defer cfg.l.Debug("loaded config") + logger.Debug("loading config") + defer logger.Debug("loaded config") data, err := E.Check(os.ReadFile(common.ConfigPath)) if err != nil { @@ -160,7 +178,7 @@ func (cfg *Config) load() (res E.NestedError) { b.Add(cfg.loadProviders(&model.Providers)) cfg.value = model - R.SetFindMuxDomains(model.MatchDomains) + route.SetFindMuxDomains(model.MatchDomains) return } @@ -169,8 +187,8 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested return } - cfg.l.Debug("initializing autocert") - defer cfg.l.Debug("initialized autocert") + logger.Debug("initializing autocert") + defer logger.Debug("initialized autocert") cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() if err != nil { @@ -179,48 +197,34 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested return } -func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) { - cfg.l.Debug("loading providers") - defer cfg.l.Debug("loaded providers") +func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) { + subtask := cfg.task.Subtask("load providers") + defer subtask.Finish("done") - b := E.NewBuilder("errors loading providers") - defer b.To(&res) + errs := E.NewBuilder("errors loading providers") + results := E.NewBuilder("loaded providers") + defer errs.To(&outErr) for _, filename := range providers.Files { - p, err := PR.NewFileProvider(filename) + p, err := proxy.NewFileProvider(filename) if err != nil { - b.Add(err.Subject(filename)) + errs.Add(err) continue } - cfg.proxyProviders.Store(p.GetName(), p) - b.Add(p.LoadRoutes().Subject(filename)) + cfg.providers.Store(p.GetName(), p) + errs.Add(p.LoadRoutes().Subject(filename)) + results.Addf("%d routes from %s", p.NumRoutes(), filename) } for name, dockerHost := range providers.Docker { - p, err := PR.NewDockerProvider(name, dockerHost) + p, err := proxy.NewDockerProvider(name, dockerHost) if err != nil { - b.Add(err.Subjectf("%s (%s)", name, dockerHost)) + errs.Add(err.Subjectf("%s (%s)", name, dockerHost)) continue } - cfg.proxyProviders.Store(p.GetName(), p) - b.Add(p.LoadRoutes().Subject(p.GetName())) + cfg.providers.Store(p.GetName(), p) + errs.Add(p.LoadRoutes().Subject(p.GetName())) + results.Addf("%d routes from %s", p.NumRoutes(), name) } + logger.Info(results.Build()) return } - -func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { - errors := E.NewBuilder("errors in %s these providers", action) - - cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) { - if err := do(p); err != nil { - errors.Add(err.Subject(p)) - } - }) - - if err := errors.Build(); err != nil { - cfg.l.Error(err) - } -} - -func (cfg *Config) stopProviders() { - cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes) -} diff --git a/internal/config/query.go b/internal/config/query.go index 5243169..026512d 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -6,33 +6,35 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/homepage" - PR "github.com/yusing/go-proxy/internal/proxy/provider" - R "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/proxy/entry" + "github.com/yusing/go-proxy/internal/route" + proxy "github.com/yusing/go-proxy/internal/route/provider" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) -func (cfg *Config) DumpEntries() map[string]*types.RawEntry { - entries := make(map[string]*types.RawEntry) - cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { - entries[alias] = r.Entry +func DumpEntries() map[string]*entry.RawEntry { + entries := make(map[string]*entry.RawEntry) + instance.providers.RangeAll(func(_ string, p *proxy.Provider) { + p.RangeRoutes(func(alias string, r *route.Route) { + entries[alias] = r.Entry + }) }) return entries } -func (cfg *Config) DumpProviders() map[string]*PR.Provider { - entries := make(map[string]*PR.Provider) - cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { +func DumpProviders() map[string]*proxy.Provider { + entries := make(map[string]*proxy.Provider) + instance.providers.RangeAll(func(name string, p *proxy.Provider) { entries[name] = p }) return entries } -func (cfg *Config) HomepageConfig() homepage.Config { +func HomepageConfig() homepage.Config { var proto, port string - domains := cfg.value.MatchDomains - cert, _ := cfg.autocertProvider.GetCert(nil) + domains := instance.value.MatchDomains + cert, _ := instance.autocertProvider.GetCert(nil) if cert != nil { proto = "https" port = common.ProxyHTTPSPort @@ -42,9 +44,9 @@ func (cfg *Config) HomepageConfig() homepage.Config { } hpCfg := homepage.NewHomePageConfig() - R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { - entry := r.Raw - item := entry.Homepage + route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) { + en := r.Raw + item := en.Homepage if item == nil { item = new(homepage.Item) item.Show = true @@ -63,12 +65,12 @@ func (cfg *Config) HomepageConfig() homepage.Config { ) } - if r.IsDocker() { + if entry.IsDocker(r) { if item.Category == "" { item.Category = "Docker" } - item.SourceType = string(PR.ProviderTypeDocker) - } else if r.UseLoadBalance() { + item.SourceType = string(proxy.ProviderTypeDocker) + } else if entry.UseLoadBalance(r) { if item.Category == "" { item.Category = "Load-balanced" } @@ -77,7 +79,7 @@ func (cfg *Config) HomepageConfig() homepage.Config { if item.Category == "" { item.Category = "Others" } - item.SourceType = string(PR.ProviderTypeFile) + item.SourceType = string(proxy.ProviderTypeFile) } if item.URL == "" { @@ -85,26 +87,26 @@ func (cfg *Config) HomepageConfig() homepage.Config { item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port) } } - item.AltURL = r.URL().String() + item.AltURL = r.TargetURL().String() hpCfg.Add(item) }) return hpCfg } -func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any { +func RoutesByAlias(typeFilter ...route.RouteType) map[string]any { routes := make(map[string]any) if len(typeFilter) == 0 || typeFilter[0] == "" { - typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream} + typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream} } for _, t := range typeFilter { switch t { - case R.RouteTypeReverseProxy: - R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { + case route.RouteTypeReverseProxy: + route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) { routes[alias] = r }) - case R.RouteTypeStream: - R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) { + case route.RouteTypeStream: + route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) { routes[alias] = r }) } @@ -112,12 +114,12 @@ func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any { return routes } -func (cfg *Config) Statistics() map[string]any { +func Statistics() map[string]any { nTotalStreams := 0 nTotalRPs := 0 - providerStats := make(map[string]PR.ProviderStats) + providerStats := make(map[string]proxy.ProviderStats) - cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { + instance.providers.RangeAll(func(name string, p *proxy.Provider) { providerStats[name] = p.Statistics() }) @@ -133,9 +135,9 @@ func (cfg *Config) Statistics() map[string]any { } } -func (cfg *Config) FindRoute(alias string) *R.Route { - return F.MapFind(cfg.proxyProviders, - func(p *PR.Provider) (*R.Route, bool) { +func FindRoute(alias string) *route.Route { + return F.MapFind(instance.providers, + func(p *proxy.Provider) (*route.Route, bool) { if route, ok := p.GetRoute(alias); ok { return route, true } diff --git a/internal/types/autocert_config.go b/internal/config/types/autocert_config.go similarity index 100% rename from internal/types/autocert_config.go rename to internal/config/types/autocert_config.go diff --git a/internal/config/types/config.go b/internal/config/types/config.go new file mode 100644 index 0000000..ed0e638 --- /dev/null +++ b/internal/config/types/config.go @@ -0,0 +1,24 @@ +package types + +type ( + Config struct { + Providers ProxyProviders `json:"providers" yaml:",flow"` + AutoCert AutoCertConfig `json:"autocert" yaml:",flow"` + ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"` + MatchDomains []string `json:"match_domains" yaml:"match_domains"` + TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"` + RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"` + } + ProxyProviders struct { + Files []string `json:"include" yaml:"include"` // docker, file + Docker map[string]string `json:"docker" yaml:"docker"` + } +) + +func DefaultConfig() *Config { + return &Config{ + Providers: ProxyProviders{}, + TimeoutShutdown: 3, + RedirectToHTTPS: false, + } +} diff --git a/internal/docker/client.go b/internal/docker/client.go index c8d9941..570f41a 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -36,22 +37,13 @@ var ( ) func init() { - go func() { - task := common.NewTask("close all docker client") - defer task.Finished() - for { - select { - case <-task.Context().Done(): - clientMap.RangeAllParallel(func(_ string, c Client) { - if c.Connected() { - c.Client.Close() - } - }) - clientMap.Clear() - return + task.GlobalTask("close docker clients").OnComplete("", func() { + clientMap.RangeAllParallel(func(_ string, c Client) { + if c.Connected() { + c.Client.Close() } - } - }() + }) + }) } func (c *SharedClient) Connected() bool { @@ -141,19 +133,10 @@ func ConnectClient(host string) (Client, E.NestedError) { <-c.refCount.Zero() clientMap.Delete(c.key) - if c.Client != nil { + if c.Connected() { c.Client.Close() - c.Client = nil c.l.Debugf("client closed") } }() return c, nil } - -func CloseAllClients() { - clientMap.RangeAllParallel(func(_ string, c Client) { - c.Client.Close() - }) - clientMap.Clear() - logger.Debug("closed all clients") -} diff --git a/internal/docker/client_info.go b/internal/docker/client_info.go index 6228920..751489e 100644 --- a/internal/docker/client_info.go +++ b/internal/docker/client_info.go @@ -2,6 +2,7 @@ package docker import ( "context" + "errors" "time" "github.com/docker/docker/api/types" @@ -16,10 +17,13 @@ type ClientInfo struct { } var listOptions = container.ListOptions{ + // created|restarting|running|removing|paused|exited|dead // Filters: filters.NewArgs( - // filters.Arg("health", "healthy"), - // filters.Arg("health", "none"), - // filters.Arg("health", "starting"), + // filters.Arg("status", "created"), + // filters.Arg("status", "restarting"), + // filters.Arg("status", "running"), + // filters.Arg("status", "paused"), + // filters.Arg("status", "exited"), // ), All: true, } @@ -31,7 +35,7 @@ func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedE } defer dockerClient.Close() - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout")) defer cancel() var containers []types.Container diff --git a/internal/docker/container.go b/internal/docker/container.go index 115e520..d0afa9f 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -32,11 +32,11 @@ type ( IsExcluded bool `json:"is_excluded" yaml:"-"` IsExplicit bool `json:"is_explicit" yaml:"-"` IsDatabase bool `json:"is_database" yaml:"-"` - IdleTimeout string `json:"idle_timeout" yaml:"-"` - WakeTimeout string `json:"wake_timeout" yaml:"-"` - StopMethod string `json:"stop_method" yaml:"-"` - StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only - StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only + IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"` + WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"` + StopMethod string `json:"stop_method,omitempty" yaml:"-"` + StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only + StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only Running bool `json:"running" yaml:"-"` } ) diff --git a/internal/docker/idlewatcher/config/config.go b/internal/docker/idlewatcher/config/config.go new file mode 100644 index 0000000..2ecaa18 --- /dev/null +++ b/internal/docker/idlewatcher/config/config.go @@ -0,0 +1,112 @@ +package idlewatcher + +import ( + "time" + + "github.com/yusing/go-proxy/internal/docker" + E "github.com/yusing/go-proxy/internal/error" +) + +type ( + Config struct { + IdleTimeout time.Duration `json:"idle_timeout,omitempty"` + WakeTimeout time.Duration `json:"wake_timeout,omitempty"` + StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument + StopMethod StopMethod `json:"stop_method,omitempty"` + StopSignal Signal `json:"stop_signal,omitempty"` + + DockerHost string `json:"docker_host,omitempty"` + ContainerName string `json:"container_name,omitempty"` + ContainerID string `json:"container_id,omitempty"` + ContainerRunning bool `json:"container_running,omitempty"` + } + StopMethod string + Signal string +) + +const ( + StopMethodPause StopMethod = "pause" + StopMethodStop StopMethod = "stop" + StopMethodKill StopMethod = "kill" +) + +func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) { + if cont == nil { + return nil, nil + } + + if cont.IdleTimeout == "" { + return &Config{ + DockerHost: cont.DockerHost, + ContainerName: cont.ContainerName, + ContainerID: cont.ContainerID, + ContainerRunning: cont.Running, + }, nil + } + + b := E.NewBuilder("invalid idlewatcher config") + defer b.To(&res) + + idleTimeout, err := validateDurationPostitive(cont.IdleTimeout) + b.Add(err.Subjectf("%s", "idle_timeout")) + + wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout) + b.Add(err.Subjectf("%s", "wake_timeout")) + + stopTimeout, err := validateDurationPostitive(cont.StopTimeout) + b.Add(err.Subjectf("%s", "stop_timeout")) + + stopMethod, err := validateStopMethod(cont.StopMethod) + b.Add(err) + + signal, err := validateSignal(cont.StopSignal) + b.Add(err) + + if err := b.Build(); err != nil { + return + } + + return &Config{ + IdleTimeout: idleTimeout, + WakeTimeout: wakeTimeout, + StopTimeout: int(stopTimeout.Seconds()), + StopMethod: stopMethod, + StopSignal: signal, + + DockerHost: cont.DockerHost, + ContainerName: cont.ContainerName, + ContainerID: cont.ContainerID, + ContainerRunning: cont.Running, + }, nil +} + +func validateDurationPostitive(value string) (time.Duration, E.NestedError) { + d, err := time.ParseDuration(value) + if err != nil { + return 0, E.Invalid("duration", value).With(err) + } + if d < 0 { + return 0, E.Invalid("duration", "negative value") + } + return d, nil +} + +func validateSignal(s string) (Signal, E.NestedError) { + switch s { + case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", + "INT", "TERM", "HUP", "QUIT": + return Signal(s), nil + } + + return "", E.Invalid("signal", s) +} + +func validateStopMethod(s string) (StopMethod, E.NestedError) { + sm := StopMethod(s) + switch sm { + case StopMethodPause, StopMethodStop, StopMethodKill: + return sm, nil + default: + return "", E.Invalid("stop_method", sm) + } +} diff --git a/internal/docker/idlewatcher/http.go b/internal/docker/idlewatcher/loading_page.go similarity index 69% rename from internal/docker/idlewatcher/http.go rename to internal/docker/idlewatcher/loading_page.go index bb000af..1035bc9 100644 --- a/internal/docker/idlewatcher/http.go +++ b/internal/docker/idlewatcher/loading_page.go @@ -20,16 +20,15 @@ var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(lo const headerCheckRedirect = "X-Goproxy-Check-Redirect" -func (w *Watcher) makeRespBody(format string, args ...any) []byte { - msg := fmt.Sprintf(format, args...) +func (w *Watcher) makeLoadingPageBody() []byte { + msg := fmt.Sprintf("%s is starting...", w.ContainerName) data := new(templateData) data.CheckRedirectHeader = headerCheckRedirect data.Title = w.ContainerName - data.Message = strings.ReplaceAll(msg, "\n", "
") - data.Message = strings.ReplaceAll(data.Message, " ", " ") + data.Message = strings.ReplaceAll(msg, " ", " ") - buf := bytes.NewBuffer(make([]byte, 128)) // more than enough + buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect))) err := loadingPageTmpl.Execute(buf, data) if err != nil { // should never happen in production panic(err) diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 34bbb5e..ddb85c7 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -1,197 +1,133 @@ package idlewatcher import ( - "context" "net/http" - "strconv" + "sync/atomic" "time" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/net/types" + net "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/proxy/entry" + "github.com/yusing/go-proxy/internal/task" + U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/watcher/health" ) -type Waker struct { - *Watcher +type Waker interface { + health.HealthMonitor + http.Handler + net.Stream +} + +type waker struct { + _ U.NoCopy - client *http.Client rp *gphttp.ReverseProxy + stream net.Stream + hc health.HealthChecker + + ready atomic.Bool } -func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker { - return &Waker{ - Watcher: w, - client: &http.Client{ - Timeout: 1 * time.Second, - Transport: rp.Transport, - }, - rp: rp, +const ( + idleWakerCheckInterval = 100 * time.Millisecond + idleWakerCheckTimeout = time.Second +) + +// TODO: support stream + +func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) { + hcCfg := entry.HealthCheckConfig() + hcCfg.Timeout = idleWakerCheckTimeout + + waker := &waker{ + rp: rp, + stream: stream, } -} -func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - shouldNext := w.wake(rw, r) - if !shouldNext { - return + watcher, err := registerWatcher(providerSubTask, entry, waker) + if err != nil { + return nil, err } - w.rp.ServeHTTP(rw, r) + + if rp != nil { + waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport) + } else if stream != nil { + waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg) + } else { + panic("both nil") + } + return watcher, nil } -/* HealthMonitor interface */ - -func (w *Waker) Start() {} - -func (w *Waker) Stop() { - w.Unregister() +// lifetime should follow route provider +func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) { + return newWaker(providerSubTask, entry, rp, nil) } -func (w *Waker) UpdateConfig(config health.HealthCheckConfig) { - panic("use idlewatcher.Register instead") +func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) { + return newWaker(providerSubTask, entry, nil, stream) } -func (w *Waker) Name() string { +// Start implements health.HealthMonitor. +func (w *Watcher) Start(routeSubTask task.Task) E.NestedError { + w.task.OnComplete("stop route", func() { + routeSubTask.Parent().Finish("watcher stopped") + }) + return nil +} + +// Finish implements health.HealthMonitor. +func (w *Watcher) Finish(reason string) {} + +// Name implements health.HealthMonitor. +func (w *Watcher) Name() string { return w.String() } -func (w *Waker) String() string { - return string(w.Alias) +// String implements health.HealthMonitor. +func (w *Watcher) String() string { + return w.ContainerName } -func (w *Waker) Status() health.Status { - if w.ready.Load() { - return health.StatusHealthy - } - if !w.ContainerRunning { - return health.StatusNapping - } - return health.StatusStarting -} - -func (w *Waker) Uptime() time.Duration { +// Uptime implements health.HealthMonitor. +func (w *Watcher) Uptime() time.Duration { return 0 } -func (w *Waker) MarshalJSON() ([]byte, error) { - var url types.URL - if w.URL.String() != "http://:0" { - url = w.URL +// Status implements health.HealthMonitor. +func (w *Watcher) Status() health.Status { + if !w.ContainerRunning { + return health.StatusNapping + } + + if w.ready.Load() { + return health.StatusHealthy + } + + healthy, _, err := w.hc.CheckHealth() + switch { + case err != nil: + return health.StatusError + case healthy: + w.ready.Store(true) + return health.StatusHealthy + default: + return health.StatusStarting + } +} + +// MarshalJSON implements health.HealthMonitor. +func (w *Watcher) MarshalJSON() ([]byte, error) { + var url net.URL + if w.hc.URL().Port() != "0" { + url = w.hc.URL() } return (&health.JSONRepresentation{ Name: w.Name(), Status: w.Status(), - Config: &health.HealthCheckConfig{ - Interval: w.IdleTimeout, - Timeout: w.WakeTimeout, - }, - URL: url, + Config: w.hc.Config(), + URL: url, }).MarshalJSON() } - -/* End of HealthMonitor interface */ - -func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { - w.resetIdleTimer() - - if r.Body != nil { - defer r.Body.Close() - } - - // pass through if container is ready - if w.ready.Load() { - return true - } - - ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout) - defer cancel() - - accept := gphttp.GetAccept(r.Header) - acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) - - isCheckRedirect := r.Header.Get(headerCheckRedirect) != "" - if !isCheckRedirect && acceptHTML { - // Send a loading response to the client - body := w.makeRespBody("%s waking up...", w.ContainerName) - rw.Header().Set("Content-Type", "text/html; charset=utf-8") - rw.Header().Set("Content-Length", strconv.Itoa(len(body))) - rw.Header().Add("Cache-Control", "no-cache") - rw.Header().Add("Cache-Control", "no-store") - rw.Header().Add("Cache-Control", "must-revalidate") - if _, err := rw.Write(body); err != nil { - w.l.Errorf("error writing http response: %s", err) - } - return - } - - select { - case <-w.task.Context().Done(): - http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) - return - case <-ctx.Done(): - http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) - return - default: - } - - w.l.Debug("wake signal received") - err := w.wakeIfStopped() - if err != nil { - w.l.Error(E.FailWith("wake", err)) - http.Error(rw, "Error waking container", http.StatusInternalServerError) - return - } - - // maybe another request came in while we were waiting for the wake - if w.ready.Load() { - if isCheckRedirect { - rw.WriteHeader(http.StatusOK) - return - } - return true - } - - for { - select { - case <-w.task.Context().Done(): - http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) - return - case <-ctx.Done(): - http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) - return - default: - } - - wakeReq, err := http.NewRequestWithContext( - ctx, - http.MethodHead, - w.URL.String(), - nil, - ) - if err != nil { - w.l.Errorf("new request err to %s: %s", r.URL, err) - http.Error(rw, "Internal server error", http.StatusInternalServerError) - return - } - - wakeResp, err := w.client.Do(wakeReq) - if err == nil && wakeResp.StatusCode != http.StatusServiceUnavailable { - w.ready.Store(true) - w.l.Debug("awaken") - if isCheckRedirect { - rw.WriteHeader(http.StatusOK) - return - } - logrus.Infof("container %s is ready, passing through to %s", w.Alias, w.rp.TargetURL) - return true - } - - // retry until the container is ready or timeout - time.Sleep(100 * time.Millisecond) - } -} - -// static HealthMonitor interface check -func (w *Waker) _() health.HealthMonitor { - return w -} diff --git a/internal/docker/idlewatcher/waker_http.go b/internal/docker/idlewatcher/waker_http.go new file mode 100644 index 0000000..3333280 --- /dev/null +++ b/internal/docker/idlewatcher/waker_http.go @@ -0,0 +1,105 @@ +package idlewatcher + +import ( + "context" + "errors" + "net/http" + "strconv" + "time" + + "github.com/sirupsen/logrus" + E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +// ServeHTTP implements http.Handler +func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + shouldNext := w.wakeFromHTTP(rw, r) + if !shouldNext { + return + } + w.rp.ServeHTTP(rw, r) +} + +func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { + w.resetIdleTimer() + + if r.Body != nil { + defer r.Body.Close() + } + + // pass through if container is already ready + if w.ready.Load() { + return true + } + + accept := gphttp.GetAccept(r.Header) + acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) + + isCheckRedirect := r.Header.Get(headerCheckRedirect) != "" + if !isCheckRedirect && acceptHTML { + // Send a loading response to the client + body := w.makeLoadingPageBody() + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Header().Set("Content-Length", strconv.Itoa(len(body))) + rw.Header().Add("Cache-Control", "no-cache") + rw.Header().Add("Cache-Control", "no-store") + rw.Header().Add("Cache-Control", "must-revalidate") + rw.Header().Add("Connection", "close") + if _, err := rw.Write(body); err != nil { + w.l.Errorf("error writing http response: %s", err) + } + return + } + + ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) + defer cancel() + + checkCancelled := func() bool { + select { + case <-w.task.Context().Done(): + w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context())) + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + return true + case <-ctx.Done(): + w.l.Debugf("wake cancelled: %s", context.Cause(ctx)) + http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) + return true + default: + return false + } + } + + if checkCancelled() { + return false + } + + w.l.Debug("wake signal received") + err := w.wakeIfStopped() + if err != nil { + w.l.Error(E.FailWith("wake", err)) + http.Error(rw, "Error waking container", http.StatusInternalServerError) + return + } + + for { + if checkCancelled() { + return false + } + + if w.Status() == health.StatusHealthy { + w.resetIdleTimer() + if isCheckRedirect { + logrus.Debugf("container %s is ready, redirecting...", w.String()) + rw.WriteHeader(http.StatusOK) + return + } + logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + return true + } + + // retry until the container is ready or timeout + time.Sleep(idleWakerCheckInterval) + } +} diff --git a/internal/docker/idlewatcher/waker_stream.go b/internal/docker/idlewatcher/waker_stream.go new file mode 100644 index 0000000..326ebeb --- /dev/null +++ b/internal/docker/idlewatcher/waker_stream.go @@ -0,0 +1,87 @@ +package idlewatcher + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +// Setup implements types.Stream. +func (w *Watcher) Setup() error { + return w.stream.Setup() +} + +// Accept implements types.Stream. +func (w *Watcher) Accept() (conn types.StreamConn, err error) { + conn, err = w.stream.Accept() + // timeout means no connection is accepted + var nErr *net.OpError + ok := errors.As(err, &nErr) + if ok && nErr.Timeout() { + return + } + if err := w.wakeFromStream(); err != nil { + return nil, err + } + return w.stream.Accept() +} + +// CloseListeners implements types.Stream. +func (w *Watcher) CloseListeners() { + w.stream.CloseListeners() +} + +// Handle implements types.Stream. +func (w *Watcher) Handle(conn types.StreamConn) error { + if err := w.wakeFromStream(); err != nil { + return err + } + return w.stream.Handle(conn) +} + +func (w *Watcher) wakeFromStream() error { + // pass through if container is already ready + if w.ready.Load() { + return nil + } + + w.l.Debug("wake signal received") + wakeErr := w.wakeIfStopped() + if wakeErr != nil { + wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr) + w.l.Error(wakeErr) + return wakeErr + } + + ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout")) + defer cancel() + + for { + select { + case <-w.task.Context().Done(): + cause := w.task.FinishCause() + w.l.Debugf("wake cancelled: %s", cause) + return cause + case <-ctx.Done(): + cause := context.Cause(ctx) + w.l.Debugf("wake cancelled: %s", cause) + return cause + default: + } + + if w.Status() == health.StatusHealthy { + w.resetIdleTimer() + logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + return nil + } + + // retry until the container is ready or timeout + time.Sleep(idleWakerCheckInterval) + } +} diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index 7ad67ce..1997d5a 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -2,191 +2,193 @@ package idlewatcher import ( "context" + "errors" + "fmt" "sync" - "sync/atomic" "time" "github.com/docker/docker/api/types/container" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" E "github.com/yusing/go-proxy/internal/error" - P "github.com/yusing/go-proxy/internal/proxy" - PT "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/proxy/entry" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/watcher/events" ) type ( Watcher struct { - *P.ReverseProxyEntry + _ U.NoCopy - client D.Client + *idlewatcher.Config + *waker - ready atomic.Bool // whether the site is ready to accept connection + client D.Client stopByMethod StopCallback // send a docker command w.r.t. `stop_method` - - ticker *time.Ticker - - task common.Task - cancel context.CancelFunc - - refCount *U.RefCount - - l logrus.FieldLogger + ticker *time.Ticker + task task.Task + l *logrus.Entry } WakeDone <-chan error WakeFunc func() WakeDone - StopCallback func() E.NestedError + StopCallback func() error ) var ( watcherMap = F.NewMapOf[string, *Watcher]() watcherMapMu sync.Mutex - portHistoryMap = F.NewMapOf[PT.Alias, string]() - logger = logrus.WithField("module", "idle_watcher") ) -func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) { - failure := E.Failure("idle_watcher register") +const dockerReqTimeout = 3 * time.Second - if entry.IdleTimeout == 0 { - return nil, failure.With(E.Invalid("idle_timeout", 0)) +func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) { + failure := E.Failure("idle_watcher register") + cfg := entry.IdlewatcherConfig() + + if cfg.IdleTimeout == 0 { + panic("should not reach here") } watcherMapMu.Lock() defer watcherMapMu.Unlock() - key := entry.ContainerID - - if entry.URL.Port() != "0" { - portHistoryMap.Store(entry.Alias, entry.URL.Port()) - } + key := cfg.ContainerID if w, ok := watcherMap.Load(key); ok { - w.refCount.Add() - w.ReverseProxyEntry = entry + w.Config = cfg + w.waker = waker + w.resetIdleTimer() return w, nil } - client, err := D.ConnectClient(entry.DockerHost) + client, err := D.ConnectClient(cfg.DockerHost) if err.HasError() { return nil, failure.With(err) } w := &Watcher{ - ReverseProxyEntry: entry, - client: client, - refCount: U.NewRefCounter(), - ticker: time.NewTicker(entry.IdleTimeout), - l: logger.WithField("container", entry.ContainerName), + Config: cfg, + waker: waker, + client: client, + task: providerSubtask, + ticker: time.NewTicker(cfg.IdleTimeout), + l: logger.WithField("container", cfg.ContainerName), } - w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias) w.stopByMethod = w.getStopCallback() - watcherMap.Store(key, w) - go w.watchUntilCancel() + go func() { + cause := w.watchUntilDestroy() + + watcherMapMu.Lock() + watcherMap.Delete(w.ContainerID) + watcherMapMu.Unlock() + + w.ticker.Stop() + w.client.Close() + w.task.Finish(cause.Error()) + }() return w, nil } -func (w *Watcher) Unregister() { - w.refCount.Sub() -} - -func (w *Watcher) containerStop() error { - return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{ +func (w *Watcher) containerStop(ctx context.Context) error { + return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{ Signal: string(w.StopSignal), Timeout: &w.StopTimeout, }) } -func (w *Watcher) containerPause() error { - return w.client.ContainerPause(w.task.Context(), w.ContainerID) +func (w *Watcher) containerPause(ctx context.Context) error { + return w.client.ContainerPause(ctx, w.ContainerID) } -func (w *Watcher) containerKill() error { - return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal)) +func (w *Watcher) containerKill(ctx context.Context) error { + return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal)) } -func (w *Watcher) containerUnpause() error { - return w.client.ContainerUnpause(w.task.Context(), w.ContainerID) +func (w *Watcher) containerUnpause(ctx context.Context) error { + return w.client.ContainerUnpause(ctx, w.ContainerID) } -func (w *Watcher) containerStart() error { - return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{}) +func (w *Watcher) containerStart(ctx context.Context) error { + return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{}) } -func (w *Watcher) containerStatus() (string, E.NestedError) { +func (w *Watcher) containerStatus() (string, error) { if !w.client.Connected() { - return "", E.Failure("docker client closed") + return "", errors.New("docker client not connected") } - json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID) + ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout")) + defer cancel() + json, err := w.client.ContainerInspect(ctx, w.ContainerID) if err != nil { - return "", E.FailWith("inspect container", err) + return "", fmt.Errorf("failed to inspect container: %w", err) } return json.State.Status, nil } -func (w *Watcher) wakeIfStopped() E.NestedError { - if w.ready.Load() || w.ContainerRunning { +func (w *Watcher) wakeIfStopped() error { + if w.ContainerRunning { return nil } status, err := w.containerStatus() - - if err.HasError() { + if err != nil { return err } - // "created", "running", "paused", "restarting", "removing", "exited", or "dead" + + ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) + defer cancel() + + // !Hard coded here since theres no constants from Docker API switch status { case "exited", "dead": - return E.From(w.containerStart()) + return w.containerStart(ctx) case "paused": - return E.From(w.containerUnpause()) + return w.containerUnpause(ctx) case "running": return nil default: - return E.Unexpected("container state", status) + panic("should not reach here") } } func (w *Watcher) getStopCallback() StopCallback { - var cb func() error + var cb func(context.Context) error switch w.StopMethod { - case PT.StopMethodPause: + case idlewatcher.StopMethodPause: cb = w.containerPause - case PT.StopMethodStop: + case idlewatcher.StopMethodStop: cb = w.containerStop - case PT.StopMethodKill: + case idlewatcher.StopMethodKill: cb = w.containerKill default: panic("should not reach here") } - return func() E.NestedError { - status, err := w.containerStatus() - if err.HasError() { - return err - } - if status != "running" { - return nil - } - return E.From(cb()) + return func() error { + ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) + defer cancel() + return cb(ctx) } } func (w *Watcher) resetIdleTimer() { + w.l.Trace("reset idle timer") w.ticker.Reset(w.IdleTimeout) } -func (w *Watcher) watchUntilCancel() { - dockerWatcher := W.NewDockerWatcherWithClient(w.client) - dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{ +func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) { + eventTask = w.task.Subtask("watcher for %s", w.ContainerID) + eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, W.DockerrFilterContainer(w.ContainerID), @@ -194,34 +196,47 @@ func (w *Watcher) watchUntilCancel() { W.DockerFilterStop, W.DockerFilterDie, W.DockerFilterKill, + W.DockerFilterDestroy, W.DockerFilterPause, W.DockerFilterUnpause, ), }) + return +} - defer func() { - w.cancel() - w.ticker.Stop() - w.client.Close() - watcherMap.Delete(w.ContainerID) - w.task.Finished() - }() +// watchUntilDestroy waits for the container to be created, started, or unpaused, +// and then reset the idle timer. +// +// When the container is stopped, paused, +// or killed, the idle timer is stopped and the ContainerRunning flag is set to false. +// +// When the idle timer fires, the container is stopped according to the +// stop method. +// +// it exits only if the context is canceled, the container is destroyed, +// errors occured on docker client, or route provider died (mainly caused by config reload). +func (w *Watcher) watchUntilDestroy() error { + dockerWatcher := W.NewDockerWatcherWithClient(w.client) + eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) for { select { case <-w.task.Context().Done(): - w.l.Debug("stopped by context done") - return - case <-w.refCount.Zero(): - w.l.Debug("stopped by zero ref count") - return + cause := context.Cause(w.task.Context()) + w.l.Debugf("watcher stopped by context done: %s", cause) + return cause case err := <-dockerEventErrCh: if err != nil && err.IsNot(context.Canceled) { w.l.Error(E.FailWith("docker watcher", err)) - return + return err.Error() } case e := <-dockerEventCh: switch { + case e.Action == events.ActionContainerDestroy: + w.ContainerRunning = false + w.ready.Store(false) + w.l.Info("watcher stopped by container destruction") + return errors.New("container destroyed") // create / start / unpause case e.Action.IsContainerWake(): w.ContainerRunning = true @@ -229,18 +244,31 @@ func (w *Watcher) watchUntilCancel() { w.l.Info("container awaken") case e.Action.IsContainerSleep(): // stop / pause / kil w.ContainerRunning = false - w.ticker.Stop() w.ready.Store(false) + w.ticker.Stop() default: w.l.Errorf("unexpected docker event: %s", e) } + // container name changed should also change the container id + if w.ContainerName != e.ActorName { + w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName) + w.ContainerName = e.ActorName + } + if w.ContainerID != e.ActorID { + w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID) + w.ContainerID = e.ActorID + // recreate event stream + eventTask.Finish("recreate event stream") + eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher) + } case <-w.ticker.C: - w.l.Debug("idle timeout") w.ticker.Stop() - if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) { - w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod)) - } else { - w.l.Info("stopped by idle timeout") + if w.ContainerRunning { + if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) { + w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err) + } else { + w.l.Info("container stopped by idle timeout") + } } } } diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index ae277ee..9d8d854 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -2,6 +2,7 @@ package docker import ( "context" + "errors" "time" E "github.com/yusing/go-proxy/internal/error" @@ -19,7 +20,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) } func (c Client) Inspect(containerID string) (*Container, E.NestedError) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) defer cancel() json, err := c.ContainerInspect(ctx, containerID) diff --git a/internal/error/builder.go b/internal/error/builder.go index e8e849d..e0d866b 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -17,35 +17,36 @@ type builder struct { } func NewBuilder(format string, args ...any) Builder { - return Builder{&builder{message: fmt.Sprintf(format, args...)}} + if len(args) > 0 { + return Builder{&builder{message: fmt.Sprintf(format, args...)}} + } + return Builder{&builder{message: format}} } // adding nil / nil is no-op, // you may safely pass expressions returning error to it. -func (b Builder) Add(err NestedError) Builder { +func (b Builder) Add(err NestedError) { if err != nil { b.Lock() b.errors = append(b.errors, err) b.Unlock() } - return b } -func (b Builder) AddE(err error) Builder { - return b.Add(From(err)) +func (b Builder) AddE(err error) { + b.Add(From(err)) } -func (b Builder) Addf(format string, args ...any) Builder { - return b.Add(errorf(format, args...)) +func (b Builder) Addf(format string, args ...any) { + b.Add(errorf(format, args...)) } -func (b Builder) AddRangeE(errs ...error) Builder { +func (b Builder) AddRangeE(errs ...error) { b.Lock() defer b.Unlock() for _, err := range errs { b.AddE(err) } - return b } // Build builds a NestedError based on the errors collected in the Builder. diff --git a/internal/error/errors.go b/internal/error/errors.go index 8728c73..896b108 100644 --- a/internal/error/errors.go +++ b/internal/error/errors.go @@ -2,6 +2,7 @@ package error import ( stderrors "errors" + "fmt" "reflect" ) @@ -16,6 +17,7 @@ var ( ErrOutOfRange = stderrors.New("out of range") ErrTypeError = stderrors.New("type error") ErrTypeMismatch = stderrors.New("type mismatch") + ErrPanicRecv = stderrors.New("panic") ) const fmtSubjectWhat = "%w %v: %q" @@ -75,3 +77,7 @@ func TypeError2(subject any, from, to reflect.Value) NestedError { func TypeMismatch[Expect any](value any) NestedError { return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) } + +func PanicRecv(format string, args ...any) NestedError { + return errorf("%w%s", ErrPanicRecv, fmt.Sprintf(format, args...)) +} diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index da32778..447420f 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -4,18 +4,20 @@ import ( "hash/fnv" "net" "net/http" + "sync" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" ) type ipHash struct { - *LoadBalancer realIP *middleware.Middleware + pool servers + mu sync.Mutex } func (lb *LoadBalancer) newIPHash() impl { - impl := &ipHash{LoadBalancer: lb} + impl := new(ipHash) if len(lb.Options) == 0 { return impl } @@ -26,10 +28,37 @@ func (lb *LoadBalancer) newIPHash() impl { } return impl } -func (ipHash) OnAddServer(srv *Server) {} -func (ipHash) OnRemoveServer(srv *Server) {} -func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) { +func (impl *ipHash) OnAddServer(srv *Server) { + impl.mu.Lock() + defer impl.mu.Unlock() + + for i, s := range impl.pool { + if s == srv { + return + } + if s == nil { + impl.pool[i] = srv + return + } + } + + impl.pool = append(impl.pool, srv) +} + +func (impl *ipHash) OnRemoveServer(srv *Server) { + impl.mu.Lock() + defer impl.mu.Unlock() + + for i, s := range impl.pool { + if s == srv { + impl.pool[i] = nil + return + } + } +} + +func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) { if impl.realIP != nil { impl.realIP.ModifyRequest(impl.serveHTTP, rw, r) } else { @@ -37,7 +66,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) } } -func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { +func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { http.Error(rw, "Internal error", http.StatusInternalServerError) @@ -45,10 +74,12 @@ func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { return } idx := hashIP(ip) % uint32(len(impl.pool)) - if impl.pool[idx].Status().Bad() { + + srv := impl.pool[idx] + if srv == nil || srv.Status().Bad() { http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) } - impl.pool[idx].ServeHTTP(rw, r) + srv.ServeHTTP(rw, r) } func hashIP(ip string) uint32 { diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 0d3bee1..c7c6cc0 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -5,8 +5,9 @@ import ( "sync" "time" - "github.com/go-acme/lego/v4/log" + E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -28,7 +29,9 @@ type ( impl *Config - pool servers + task task.Task + + pool Pool poolMu sync.Mutex sumWeight weightType @@ -41,11 +44,35 @@ type ( const maxWeight weightType = 100 func New(cfg *Config) *LoadBalancer { - lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)} + lb := &LoadBalancer{ + Config: new(Config), + pool: newPool(), + task: task.DummyTask(), + } lb.UpdateConfigIfNeeded(cfg) return lb } +// Start implements task.TaskStarter. +func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError { + lb.startTime = time.Now() + lb.task = routeSubtask + lb.task.OnComplete("loadbalancer cleanup", func() { + if lb.impl != nil { + lb.pool.RangeAll(func(k string, v *Server) { + lb.impl.OnRemoveServer(v) + }) + } + lb.pool.Clear() + }) + return nil +} + +// Finish implements task.TaskFinisher. +func (lb *LoadBalancer) Finish(reason string) { + lb.task.Finish(reason) +} + func (lb *LoadBalancer) updateImpl() { switch lb.Mode { case Unset, RoundRobin: @@ -57,9 +84,9 @@ func (lb *LoadBalancer) updateImpl() { default: // should happen in test only lb.impl = lb.newRoundRobin() } - for _, srv := range lb.pool { + lb.pool.RangeAll(func(_ string, srv *Server) { lb.impl.OnAddServer(srv) - } + }) } func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { @@ -91,55 +118,60 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - lb.pool = append(lb.pool, srv) + if lb.pool.Has(srv.Name) { + old, _ := lb.pool.Load(srv.Name) + lb.sumWeight -= old.Weight + lb.impl.OnRemoveServer(old) + } + lb.pool.Store(srv.Name, srv) lb.sumWeight += srv.Weight - lb.Rebalance() + lb.rebalance() lb.impl.OnAddServer(srv) - logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool)) + logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) } func (lb *LoadBalancer) RemoveServer(srv *Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - lb.sumWeight -= srv.Weight - lb.Rebalance() - lb.impl.OnRemoveServer(srv) - - for i, s := range lb.pool { - if s == srv { - lb.pool = append(lb.pool[:i], lb.pool[i+1:]...) - break - } - } - if lb.IsEmpty() { - lb.pool = nil + if !lb.pool.Has(srv.Name) { return } - logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool)) + lb.pool.Delete(srv.Name) + + lb.sumWeight -= srv.Weight + lb.rebalance() + lb.impl.OnRemoveServer(srv) + + if lb.pool.Size() == 0 { + lb.task.Finish("no server left") + logger.Infof("[remove] loadbalancer %s stopped", lb.Link) + return + } + + logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) } -func (lb *LoadBalancer) IsEmpty() bool { - return len(lb.pool) == 0 -} - -func (lb *LoadBalancer) Rebalance() { +func (lb *LoadBalancer) rebalance() { if lb.sumWeight == maxWeight { return } + if lb.pool.Size() == 0 { + return + } if lb.sumWeight == 0 { // distribute evenly - weightEach := maxWeight / weightType(len(lb.pool)) - remainder := maxWeight % weightType(len(lb.pool)) - for _, s := range lb.pool { + weightEach := maxWeight / weightType(lb.pool.Size()) + remainder := maxWeight % weightType(lb.pool.Size()) + lb.pool.RangeAll(func(_ string, s *Server) { s.Weight = weightEach lb.sumWeight += weightEach if remainder > 0 { s.Weight++ remainder-- } - } + }) return } @@ -147,18 +179,18 @@ func (lb *LoadBalancer) Rebalance() { scaleFactor := float64(maxWeight) / float64(lb.sumWeight) lb.sumWeight = 0 - for _, s := range lb.pool { + lb.pool.RangeAll(func(_ string, s *Server) { s.Weight = weightType(float64(s.Weight) * scaleFactor) lb.sumWeight += s.Weight - } + }) delta := maxWeight - lb.sumWeight if delta == 0 { return } - for _, s := range lb.pool { + lb.pool.Range(func(_ string, s *Server) bool { if delta == 0 { - break + return false } if delta > 0 { s.Weight++ @@ -169,7 +201,8 @@ func (lb *LoadBalancer) Rebalance() { lb.sumWeight-- delta++ } - } + return true + }) } func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { @@ -181,23 +214,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { lb.impl.ServeHTTP(srvs, rw, r) } -func (lb *LoadBalancer) Start() { - if lb.sumWeight != 0 { - log.Warnf("weighted mode not supported yet") - } - - lb.startTime = time.Now() - logger.Debugf("loadbalancer %s started", lb.Link) -} - -func (lb *LoadBalancer) Stop() { - lb.poolMu.Lock() - defer lb.poolMu.Unlock() - lb.pool = nil - - logger.Debugf("loadbalancer %s stopped", lb.Link) -} - func (lb *LoadBalancer) Uptime() time.Duration { return time.Since(lb.startTime) } @@ -205,9 +221,10 @@ func (lb *LoadBalancer) Uptime() time.Duration { // MarshalJSON implements health.HealthMonitor. func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { extra := make(map[string]any) - for _, v := range lb.pool { + lb.pool.RangeAll(func(k string, v *Server) { extra[v.Name] = v.healthMon - } + }) + return (&health.JSONRepresentation{ Name: lb.Name(), Status: lb.Status(), @@ -227,7 +244,7 @@ func (lb *LoadBalancer) Name() string { // Status implements health.HealthMonitor. func (lb *LoadBalancer) Status() health.Status { - if len(lb.pool) == 0 { + if lb.pool.Size() == 0 { return health.StatusUnknown } if len(lb.availServers()) == 0 { @@ -241,21 +258,13 @@ func (lb *LoadBalancer) String() string { return lb.Name() } -func (lb *LoadBalancer) availServers() servers { - lb.poolMu.Lock() - defer lb.poolMu.Unlock() - - avail := make(servers, 0, len(lb.pool)) - for _, s := range lb.pool { - if s.Status().Bad() { - continue +func (lb *LoadBalancer) availServers() []*Server { + avail := make([]*Server, 0, lb.pool.Size()) + lb.pool.RangeAll(func(_ string, srv *Server) { + if srv.Status().Bad() { + return } - avail = append(avail, s) - } + avail = append(avail, srv) + }) return avail } - -// static HealthMonitor interface check -func (lb *LoadBalancer) _() health.HealthMonitor { - return lb -} diff --git a/internal/net/http/loadbalancer/loadbalancer_test.go b/internal/net/http/loadbalancer/loadbalancer_test.go index 7b1a043..b180c6d 100644 --- a/internal/net/http/loadbalancer/loadbalancer_test.go +++ b/internal/net/http/loadbalancer/loadbalancer_test.go @@ -13,7 +13,7 @@ func TestRebalance(t *testing.T) { for range 10 { lb.AddServer(&Server{}) } - lb.Rebalance() + lb.rebalance() ExpectEqual(t, lb.sumWeight, maxWeight) }) t.Run("less", func(t *testing.T) { @@ -23,7 +23,7 @@ func TestRebalance(t *testing.T) { lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) - lb.Rebalance() + lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) ExpectEqual(t, lb.sumWeight, maxWeight) }) @@ -36,7 +36,7 @@ func TestRebalance(t *testing.T) { lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) - lb.Rebalance() + lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) ExpectEqual(t, lb.sumWeight, maxWeight) }) diff --git a/internal/net/http/loadbalancer/server.go b/internal/net/http/loadbalancer/server.go index 45a02d3..f8a9623 100644 --- a/internal/net/http/loadbalancer/server.go +++ b/internal/net/http/loadbalancer/server.go @@ -6,6 +6,7 @@ import ( "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" + F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -20,9 +21,12 @@ type ( handler http.Handler healthMon health.HealthMonitor } - servers []*Server + servers = []*Server + Pool = F.Map[string, *Server] ) +var newPool = F.NewMap[Pool] + func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server { srv := &Server{ Name: name, diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 5defadf..d2b760a 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -48,11 +48,11 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, } delete(def, "use") m, err := base.WithOptionsClone(def) - m.name = fmt.Sprintf("%s[%d]", name, i) if err != nil { chainErr.Add(err.Subjectf("item%d", i)) continue } + m.name = fmt.Sprintf("%s[%d]", name, i) chain = append(chain, m) } if chainErr.HasError() { diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go new file mode 100644 index 0000000..6306089 --- /dev/null +++ b/internal/net/types/stream.go @@ -0,0 +1,19 @@ +package types + +import ( + "fmt" + "net" +) + +type Stream interface { + fmt.Stringer + Setup() error + Accept() (conn StreamConn, err error) + Handle(conn StreamConn) error + CloseListeners() +} + +type StreamConn interface { + RemoteAddr() net.Addr + Close() error +} diff --git a/internal/proxy/entry.go b/internal/proxy/entry.go deleted file mode 100644 index 6801c06..0000000 --- a/internal/proxy/entry.go +++ /dev/null @@ -1,177 +0,0 @@ -package proxy - -import ( - "fmt" - "net/url" - "time" - - D "github.com/yusing/go-proxy/internal/docker" - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/net/http/loadbalancer" - net "github.com/yusing/go-proxy/internal/net/types" - T "github.com/yusing/go-proxy/internal/proxy/fields" - "github.com/yusing/go-proxy/internal/types" - "github.com/yusing/go-proxy/internal/watcher/health" -) - -type ( - ReverseProxyEntry struct { // real model after validation - Raw *types.RawEntry `json:"raw"` - - Alias T.Alias `json:"alias,omitempty"` - Scheme T.Scheme `json:"scheme,omitempty"` - URL net.URL `json:"url,omitempty"` - NoTLSVerify bool `json:"no_tls_verify,omitempty"` - PathPatterns T.PathPatterns `json:"path_patterns,omitempty"` - HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` - LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"` - Middlewares D.NestedLabelMap `json:"middlewares,omitempty"` - - /* Docker only */ - IdleTimeout time.Duration `json:"idle_timeout,omitempty"` - WakeTimeout time.Duration `json:"wake_timeout,omitempty"` - StopMethod T.StopMethod `json:"stop_method,omitempty"` - StopTimeout int `json:"stop_timeout,omitempty"` - StopSignal T.Signal `json:"stop_signal,omitempty"` - DockerHost string `json:"docker_host,omitempty"` - ContainerName string `json:"container_name,omitempty"` - ContainerID string `json:"container_id,omitempty"` - ContainerRunning bool `json:"container_running,omitempty"` - } - StreamEntry struct { - Raw *types.RawEntry `json:"raw"` - - Alias T.Alias `json:"alias,omitempty"` - Scheme T.StreamScheme `json:"scheme,omitempty"` - Host T.Host `json:"host,omitempty"` - Port T.StreamPort `json:"port,omitempty"` - Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` - } -) - -func (rp *ReverseProxyEntry) UseIdleWatcher() bool { - return rp.IdleTimeout > 0 && rp.IsDocker() -} - -func (rp *ReverseProxyEntry) UseLoadBalance() bool { - return rp.LoadBalance != nil && rp.LoadBalance.Link != "" -} - -func (rp *ReverseProxyEntry) IsDocker() bool { - return rp.DockerHost != "" -} - -func (rp *ReverseProxyEntry) IsZeroPort() bool { - return rp.URL.Port() == "0" -} - -func (rp *ReverseProxyEntry) ShouldNotServe() bool { - return rp.IsZeroPort() && !rp.UseIdleWatcher() -} - -func ValidateEntry(m *types.RawEntry) (any, E.NestedError) { - m.FillMissingFields() - - scheme, err := T.NewScheme(m.Scheme) - if err != nil { - return nil, err - } - - var entry any - e := E.NewBuilder("error validating entry") - if scheme.IsStream() { - entry = validateStreamEntry(m, e) - } else { - entry = validateRPEntry(m, scheme, e) - } - if err := e.Build(); err != nil { - return nil, err - } - return entry, nil -} - -func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry { - var stopTimeOut time.Duration - cont := m.Container - if cont == nil { - cont = D.DummyContainer - } - - host, err := T.ValidateHost(m.Host) - b.Add(err) - - port, err := T.ValidatePort(m.Port) - b.Add(err) - - pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns) - b.Add(err) - - url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) - b.Add(err) - - idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout) - b.Add(err) - - wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout) - b.Add(err) - - stopMethod, err := T.ValidateStopMethod(cont.StopMethod) - b.Add(err) - - if stopMethod == T.StopMethodStop { - stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout) - b.Add(err) - } - - stopSignal, err := T.ValidateSignal(cont.StopSignal) - b.Add(err) - - if err != nil { - return nil - } - - return &ReverseProxyEntry{ - Raw: m, - Alias: T.NewAlias(m.Alias), - Scheme: s, - URL: net.NewURL(url), - NoTLSVerify: m.NoTLSVerify, - PathPatterns: pathPatterns, - HealthCheck: &m.HealthCheck, - LoadBalance: &m.LoadBalance, - Middlewares: m.Middlewares, - IdleTimeout: idleTimeout, - WakeTimeout: wakeTimeout, - StopMethod: stopMethod, - StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument - StopSignal: stopSignal, - DockerHost: cont.DockerHost, - ContainerName: cont.ContainerName, - ContainerID: cont.ContainerID, - ContainerRunning: cont.Running, - } -} - -func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry { - host, err := T.ValidateHost(m.Host) - b.Add(err) - - port, err := T.ValidateStreamPort(m.Port) - b.Add(err) - - scheme, err := T.ValidateStreamScheme(m.Scheme) - b.Add(err) - - if b.HasError() { - return nil - } - - return &StreamEntry{ - Raw: m, - Alias: T.NewAlias(m.Alias), - Scheme: *scheme, - Host: host, - Port: port, - Healthcheck: &m.HealthCheck, - } -} diff --git a/internal/proxy/entry/entry.go b/internal/proxy/entry/entry.go new file mode 100644 index 0000000..af69ba4 --- /dev/null +++ b/internal/proxy/entry/entry.go @@ -0,0 +1,68 @@ +package entry + +import ( + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" + net "github.com/yusing/go-proxy/internal/net/types" + T "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +type Entry interface { + TargetName() string + TargetURL() net.URL + RawEntry() *RawEntry + LoadBalanceConfig() *loadbalancer.Config + HealthCheckConfig() *health.HealthCheckConfig + IdlewatcherConfig() *idlewatcher.Config +} + +func ValidateEntry(m *RawEntry) (Entry, E.NestedError) { + m.FillMissingFields() + + scheme, err := T.NewScheme(m.Scheme) + if err != nil { + return nil, err + } + + var entry Entry + e := E.NewBuilder("error validating entry") + if scheme.IsStream() { + entry = validateStreamEntry(m, e) + } else { + entry = validateRPEntry(m, scheme, e) + } + if err := e.Build(); err != nil { + return nil, err + } + return entry, nil +} + +func IsDocker(entry Entry) bool { + iw := entry.IdlewatcherConfig() + return iw != nil && iw.ContainerID != "" +} + +func IsZeroPort(entry Entry) bool { + return entry.TargetURL().Port() == "0" +} + +func ShouldNotServe(entry Entry) bool { + return IsZeroPort(entry) && !UseIdleWatcher(entry) +} + +func UseLoadBalance(entry Entry) bool { + lb := entry.LoadBalanceConfig() + return lb != nil && lb.Link != "" +} + +func UseIdleWatcher(entry Entry) bool { + iw := entry.IdlewatcherConfig() + return iw != nil && iw.IdleTimeout > 0 +} + +func UseHealthCheck(entry Entry) bool { + hc := entry.HealthCheckConfig() + return hc != nil && !hc.Disabled +} diff --git a/internal/types/raw_entry.go b/internal/proxy/entry/raw.go similarity index 68% rename from internal/types/raw_entry.go rename to internal/proxy/entry/raw.go index 9d8a56d..26f5161 100644 --- a/internal/types/raw_entry.go +++ b/internal/proxy/entry/raw.go @@ -1,4 +1,4 @@ -package types +package entry import ( "strconv" @@ -21,16 +21,16 @@ type ( // raw entry object before validation // loaded from docker labels or yaml file - Alias string `json:"-" yaml:"-"` - Scheme string `json:"scheme,omitempty" yaml:"scheme"` - Host string `json:"host,omitempty" yaml:"host"` - Port string `json:"port,omitempty" yaml:"port"` - NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only - PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only - HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` - LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"` - Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` - Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` + Alias string `json:"-" yaml:"-"` + Scheme string `json:"scheme,omitempty" yaml:"scheme"` + Host string `json:"host,omitempty" yaml:"host"` + Port string `json:"port,omitempty" yaml:"port"` + NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only + PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only + HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` + LoadBalance *loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"` + Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` + Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` /* Docker only */ Container *docker.Container `json:"container,omitempty" yaml:"-"` @@ -122,29 +122,41 @@ func (e *RawEntry) FillMissingFields() { } } - if e.HealthCheck.Interval == 0 { - e.HealthCheck.Interval = common.HealthCheckIntervalDefault + if e.HealthCheck == nil { + e.HealthCheck = new(health.HealthCheckConfig) } - if e.HealthCheck.Timeout == 0 { - e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault + + if e.HealthCheck.Disabled { + e.HealthCheck = nil + } else { + if e.HealthCheck.Interval == 0 { + e.HealthCheck.Interval = common.HealthCheckIntervalDefault + } + if e.HealthCheck.Timeout == 0 { + e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault + } } - if cont.IdleTimeout == "" { - cont.IdleTimeout = common.IdleTimeoutDefault - } - if cont.WakeTimeout == "" { - cont.WakeTimeout = common.WakeTimeoutDefault - } - if cont.StopTimeout == "" { - cont.StopTimeout = common.StopTimeoutDefault - } - if cont.StopMethod == "" { - cont.StopMethod = common.StopMethodDefault + + if cont.IdleTimeout != "" { + if cont.WakeTimeout == "" { + cont.WakeTimeout = common.WakeTimeoutDefault + } + if cont.StopTimeout == "" { + cont.StopTimeout = common.StopTimeoutDefault + } + if cont.StopMethod == "" { + cont.StopMethod = common.StopMethodDefault + } } e.Port = joinPorts(lp, pp, extra) if e.Port == "" || e.Host == "" { - e.Port = "0" + if lp != "" { + e.Port = lp + ":0" + } else { + e.Port = "0" + } } } diff --git a/internal/proxy/entry/reverse_proxy.go b/internal/proxy/entry/reverse_proxy.go new file mode 100644 index 0000000..95f4352 --- /dev/null +++ b/internal/proxy/entry/reverse_proxy.go @@ -0,0 +1,98 @@ +package entry + +import ( + "fmt" + "net/url" + + "github.com/yusing/go-proxy/internal/docker" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" + net "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +type ReverseProxyEntry struct { // real model after validation + Raw *RawEntry `json:"raw"` + + Alias fields.Alias `json:"alias,omitempty"` + Scheme fields.Scheme `json:"scheme,omitempty"` + URL net.URL `json:"url,omitempty"` + NoTLSVerify bool `json:"no_tls_verify,omitempty"` + PathPatterns fields.PathPatterns `json:"path_patterns,omitempty"` + HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` + LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"` + Middlewares docker.NestedLabelMap `json:"middlewares,omitempty"` + + /* Docker only */ + Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"` +} + +func (rp *ReverseProxyEntry) TargetName() string { + return string(rp.Alias) +} + +func (rp *ReverseProxyEntry) TargetURL() net.URL { + return rp.URL +} + +func (rp *ReverseProxyEntry) RawEntry() *RawEntry { + return rp.Raw +} + +func (rp *ReverseProxyEntry) LoadBalanceConfig() *loadbalancer.Config { + return rp.LoadBalance +} + +func (rp *ReverseProxyEntry) HealthCheckConfig() *health.HealthCheckConfig { + return rp.HealthCheck +} + +func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config { + return rp.Idlewatcher +} + +func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry { + cont := m.Container + if cont == nil { + cont = docker.DummyContainer + } + + lb := m.LoadBalance + if lb != nil && lb.Link == "" { + lb = nil + } + + host, err := fields.ValidateHost(m.Host) + b.Add(err) + + port, err := fields.ValidatePort(m.Port) + b.Add(err) + + pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns) + b.Add(err) + + url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) + b.Add(err) + + idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) + b.Add(err) + + if err != nil { + return nil + } + + return &ReverseProxyEntry{ + Raw: m, + Alias: fields.NewAlias(m.Alias), + Scheme: s, + URL: net.NewURL(url), + NoTLSVerify: m.NoTLSVerify, + PathPatterns: pathPatterns, + HealthCheck: m.HealthCheck, + LoadBalance: lb, + Middlewares: m.Middlewares, + Idlewatcher: idleWatcherCfg, + } +} diff --git a/internal/proxy/entry/stream.go b/internal/proxy/entry/stream.go new file mode 100644 index 0000000..dd74de2 --- /dev/null +++ b/internal/proxy/entry/stream.go @@ -0,0 +1,89 @@ +package entry + +import ( + "fmt" + + "github.com/yusing/go-proxy/internal/docker" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" + net "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +type StreamEntry struct { + Raw *RawEntry `json:"raw"` + + Alias fields.Alias `json:"alias,omitempty"` + Scheme fields.StreamScheme `json:"scheme,omitempty"` + URL net.URL `json:"url,omitempty"` + Host fields.Host `json:"host,omitempty"` + Port fields.StreamPort `json:"port,omitempty"` + HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` + + /* Docker only */ + Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"` +} + +func (s *StreamEntry) TargetName() string { + return string(s.Alias) +} + +func (s *StreamEntry) TargetURL() net.URL { + return s.URL +} + +func (s *StreamEntry) RawEntry() *RawEntry { + return s.Raw +} + +func (s *StreamEntry) LoadBalanceConfig() *loadbalancer.Config { + // TODO: support stream load balance + return nil +} + +func (s *StreamEntry) HealthCheckConfig() *health.HealthCheckConfig { + return s.HealthCheck +} + +func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config { + return s.Idlewatcher +} + +func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry { + cont := m.Container + if cont == nil { + cont = docker.DummyContainer + } + + host, err := fields.ValidateHost(m.Host) + b.Add(err) + + port, err := fields.ValidateStreamPort(m.Port) + b.Add(err) + + scheme, err := fields.ValidateStreamScheme(m.Scheme) + b.Add(err) + + url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort))) + b.Add(err) + + idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) + b.Add(err) + + if b.HasError() { + return nil + } + + return &StreamEntry{ + Raw: m, + Alias: fields.NewAlias(m.Alias), + Scheme: *scheme, + URL: url, + Host: host, + Port: port, + HealthCheck: m.HealthCheck, + Idlewatcher: idleWatcherCfg, + } +} diff --git a/internal/proxy/fields/signal.go b/internal/proxy/fields/signal.go deleted file mode 100644 index 0083b00..0000000 --- a/internal/proxy/fields/signal.go +++ /dev/null @@ -1,17 +0,0 @@ -package fields - -import ( - E "github.com/yusing/go-proxy/internal/error" -) - -type Signal string - -func ValidateSignal(s string) (Signal, E.NestedError) { - switch s { - case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", - "INT", "TERM", "HUP", "QUIT": - return Signal(s), nil - } - - return "", E.Invalid("signal", s) -} diff --git a/internal/proxy/fields/stop_method.go b/internal/proxy/fields/stop_method.go deleted file mode 100644 index bac9ad4..0000000 --- a/internal/proxy/fields/stop_method.go +++ /dev/null @@ -1,23 +0,0 @@ -package fields - -import ( - E "github.com/yusing/go-proxy/internal/error" -) - -type StopMethod string - -const ( - StopMethodPause StopMethod = "pause" - StopMethodStop StopMethod = "stop" - StopMethodKill StopMethod = "kill" -) - -func ValidateStopMethod(s string) (StopMethod, E.NestedError) { - sm := StopMethod(s) - switch sm { - case StopMethodPause, StopMethodStop, StopMethodKill: - return sm, nil - default: - return "", E.Invalid("stop_method", sm) - } -} diff --git a/internal/proxy/fields/timeout.go b/internal/proxy/fields/timeout.go deleted file mode 100644 index b299bea..0000000 --- a/internal/proxy/fields/timeout.go +++ /dev/null @@ -1,18 +0,0 @@ -package fields - -import ( - "time" - - E "github.com/yusing/go-proxy/internal/error" -) - -func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) { - d, err := time.ParseDuration(value) - if err != nil { - return 0, E.Invalid("duration", value) - } - if d < 0 { - return 0, E.Invalid("duration", "negative value") - } - return d, nil -} diff --git a/internal/route/http.go b/internal/route/http.go index 2e89ed4..19f57b0 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -9,23 +9,21 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/api/v1/errorpage" - "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" - url "github.com/yusing/go-proxy/internal/net/types" - P "github.com/yusing/go-proxy/internal/proxy" + "github.com/yusing/go-proxy/internal/proxy/entry" PT "github.com/yusing/go-proxy/internal/proxy/fields" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/task" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/health" ) type ( HTTPRoute struct { - *P.ReverseProxyEntry + *entry.ReverseProxyEntry HealthMon health.HealthMonitor `json:"health,omitempty"` @@ -33,6 +31,8 @@ type ( server *loadbalancer.Server handler http.Handler rp *gphttp.ReverseProxy + + task task.Task } SubdomainKey = PT.Alias @@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) { } } -func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { +func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) { var trans *http.Transport if entry.NoTLSVerify { @@ -84,12 +84,10 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { } } - httpRoutesMu.Lock() - defer httpRoutesMu.Unlock() - r := &HTTPRoute{ ReverseProxyEntry: entry, rp: rp, + task: task.DummyTask(), } return r, nil } @@ -98,39 +96,34 @@ func (r *HTTPRoute) String() string { return string(r.Alias) } -func (r *HTTPRoute) URL() url.URL { - return r.ReverseProxyEntry.URL -} - -func (r *HTTPRoute) Start() E.NestedError { - if r.ShouldNotServe() { +// Start implements task.TaskStarter. +func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError { + if entry.ShouldNotServe(r) { + providerSubtask.Finish("should not serve") return nil } httpRoutesMu.Lock() defer httpRoutesMu.Unlock() - if r.handler != nil { - return nil - } - - if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) { + if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) r.HealthCheck.Disabled = true } switch { - case r.UseIdleWatcher(): - watcher, err := idlewatcher.Register(r.ReverseProxyEntry) + case entry.UseIdleWatcher(r): + wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) + waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp) if err != nil { return err } - waker := idlewatcher.NewWaker(watcher, r.rp) r.handler = waker r.HealthMon = waker - case !r.HealthCheck.Disabled: - r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask(r.String()), r.URL(), r.HealthCheck) + case entry.UseHealthCheck(r): + r.HealthMon = health.NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheck, r.rp.Transport) } + r.task = providerSubtask if r.handler == nil { switch { @@ -146,44 +139,26 @@ func (r *HTTPRoute) Start() E.NestedError { } if r.HealthMon != nil { - r.HealthMon.Start() + if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { + logrus.Warn(E.FailWith("health monitor", err)) + } } - if r.UseLoadBalance() { + if entry.UseLoadBalance(r) { r.addToLoadBalancer() } else { httpRoutes.Store(string(r.Alias), r) + r.task.OnComplete("stop rp", func() { + httpRoutes.Delete(string(r.Alias)) + }) } return nil } -func (r *HTTPRoute) Stop() (_ E.NestedError) { - if r.handler == nil { - return - } - - httpRoutesMu.Lock() - defer httpRoutesMu.Unlock() - - if r.loadBalancer != nil { - r.removeFromLoadBalancer() - } else { - httpRoutes.Delete(string(r.Alias)) - } - - if r.HealthMon != nil { - r.HealthMon.Stop() - r.HealthMon = nil - } - - r.handler = nil - - return -} - -func (r *HTTPRoute) Started() bool { - return r.handler != nil +// Finish implements task.TaskFinisher. +func (r *HTTPRoute) Finish(reason string) { + r.task.Finish(reason) } func (r *HTTPRoute) addToLoadBalancer() { @@ -197,10 +172,14 @@ func (r *HTTPRoute) addToLoadBalancer() { } } else { lb = loadbalancer.New(r.LoadBalance) - lb.Start() + lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link) + lbTask.OnComplete("remove lb from routes", func() { + httpRoutes.Delete(r.LoadBalance.Link) + }) + lb.Start(lbTask) linked = &HTTPRoute{ - ReverseProxyEntry: &P.ReverseProxyEntry{ - Raw: &types.RawEntry{ + ReverseProxyEntry: &entry.ReverseProxyEntry{ + Raw: &entry.RawEntry{ Homepage: r.Raw.Homepage, }, Alias: PT.Alias(lb.Link), @@ -214,16 +193,9 @@ func (r *HTTPRoute) addToLoadBalancer() { r.loadBalancer = lb r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) lb.AddServer(r.server) -} - -func (r *HTTPRoute) removeFromLoadBalancer() { - r.loadBalancer.RemoveServer(r.server) - if r.loadBalancer.IsEmpty() { - httpRoutes.Delete(r.LoadBalance.Link) - logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link) - } - r.server = nil - r.loadBalancer = nil + r.task.OnComplete("remove server from lb", func() { + lb.RemoveServer(r.server) + }) } func ProxyHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/proxy/provider/docker.go b/internal/route/provider/docker.go similarity index 64% rename from internal/proxy/provider/docker.go rename to internal/route/provider/docker.go index aa48151..9c480a9 100755 --- a/internal/proxy/provider/docker.go +++ b/internal/route/provider/docker.go @@ -10,10 +10,9 @@ import ( "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/proxy/entry" R "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/types" W "github.com/yusing/go-proxy/internal/watcher" - "github.com/yusing/go-proxy/internal/watcher/events" ) type DockerProvider struct { @@ -43,7 +42,7 @@ func (p *DockerProvider) NewWatcher() W.Watcher { func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { routes = R.NewRoutes() - entries := types.NewProxyEntries() + entries := entry.NewProxyEntries() info, err := D.GetClientInfo(p.dockerHost, true) if err != nil { @@ -66,12 +65,12 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { // there may be some valid entries in `en` dups := entries.MergeFrom(newEntries) // add the duplicate proxy entries to the error - dups.RangeAll(func(k string, v *types.RawEntry) { + dups.RangeAll(func(k string, v *entry.RawEntry) { errors.Addf("duplicate alias %s", k) }) } - entries.RangeAll(func(_ string, e *types.RawEntry) { + entries.RangeAll(func(_ string, e *entry.RawEntry) { e.Container.DockerHost = p.dockerHost }) @@ -88,85 +87,10 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool { strings.HasSuffix(container.ContainerName, "-old") } -func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) { - b := E.NewBuilder("event %s error", event) - defer b.To(&res.err) - - matches := R.NewRoutes() - oldRoutes.RangeAllParallel(func(k string, v *R.Route) { - if v.Entry.Container.ContainerID == event.ActorID || - v.Entry.Container.ContainerName == event.ActorName { - matches.Store(k, v) - } - }) - //FIXME: docker event die stuck - - var newRoutes R.Routes - var err E.NestedError - - switch { - // id & container name changed - case matches.Size() == 0: - matches = oldRoutes - newRoutes, err = p.LoadRoutesImpl() - b.Add(err) - case event.Action == events.ActionContainerDestroy: - // stop all old routes - matches.RangeAllParallel(func(_ string, v *R.Route) { - oldRoutes.Delete(v.Entry.Alias) - b.Add(v.Stop()) - res.nRemoved++ - }) - return - default: - cont, err := D.Inspect(p.dockerHost, event.ActorID) - if err != nil { - b.Add(E.FailWith("inspect container", err)) - return - } - - if p.shouldIgnore(cont) { - // stop all old routes - matches.RangeAllParallel(func(_ string, v *R.Route) { - b.Add(v.Stop()) - res.nRemoved++ - }) - return - } - - entries, err := p.entriesFromContainerLabels(cont) - b.Add(err) - newRoutes, err = R.FromEntries(entries) - b.Add(err) - } - - matches.RangeAll(func(k string, v *R.Route) { - if !newRoutes.Has(k) && !oldRoutes.Has(k) { - b.Add(v.Stop()) - matches.Delete(k) - res.nRemoved++ - } - }) - - newRoutes.RangeAll(func(alias string, newRoute *R.Route) { - oldRoute, exists := oldRoutes.Load(alias) - if exists { - b.Add(oldRoute.Stop()) - res.nReloaded++ - } else { - res.nAdded++ - } - b.Add(newRoute.Start()) - oldRoutes.Store(alias, newRoute) - }) - - return -} - // Returns a list of proxy entries for a container. // Always non-nil. -func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries types.RawEntries, _ E.NestedError) { - entries = types.NewProxyEntries() +func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) { + entries = entry.NewProxyEntries() if p.shouldIgnore(container) { return @@ -174,7 +98,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent // init entries map for all aliases for _, a := range container.Aliases { - entries.Store(a, &types.RawEntry{ + entries.Store(a, &entry.RawEntry{ Alias: a, Container: container, }) @@ -186,14 +110,14 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent } // remove all entries that failed to fill in missing fields - entries.RangeAll(func(_ string, re *types.RawEntry) { + entries.RangeAll(func(_ string, re *entry.RawEntry) { re.FillMissingFields() }) return entries, errors.Build().Subject(container.ContainerName) } -func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEntries, key, val string) (res E.NestedError) { +func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) { b := E.NewBuilder("errors in label %s", key) defer b.To(&res) @@ -220,7 +144,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt } if lbl.Target == D.WildcardAlias { // apply label for all aliases - entries.RangeAll(func(a string, e *types.RawEntry) { + entries.RangeAll(func(a string, e *entry.RawEntry) { if err = D.ApplyLabel(e, lbl); err != nil { b.Add(err) } diff --git a/internal/proxy/provider/docker_test.go b/internal/route/provider/docker_test.go similarity index 93% rename from internal/proxy/provider/docker_test.go rename to internal/route/provider/docker_test.go index 89959f0..43d3cc0 100644 --- a/internal/proxy/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -10,7 +10,7 @@ import ( "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" - P "github.com/yusing/go-proxy/internal/proxy" + "github.com/yusing/go-proxy/internal/proxy/entry" T "github.com/yusing/go-proxy/internal/proxy/fields" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -46,7 +46,7 @@ func TestApplyLabelWildcard(t *testing.T) { Names: dummyNames, Labels: map[string]string{ D.LabelAliases: "a,b", - D.LabelIdleTimeout: common.IdleTimeoutDefault, + D.LabelIdleTimeout: "", D.LabelStopMethod: common.StopMethodDefault, D.LabelStopSignal: "SIGTERM", D.LabelStopTimeout: common.StopTimeoutDefault, @@ -62,7 +62,7 @@ func TestApplyLabelWildcard(t *testing.T) { "proxy.a.middlewares.middleware2.prop3": "value3", "proxy.a.middlewares.middleware2.prop4": "value4", }, - }, "")) + }, client.DefaultDockerHost)) ExpectNoError(t, err.Error()) a, ok := entries.Load("a") @@ -88,8 +88,8 @@ func TestApplyLabelWildcard(t *testing.T) { ExpectDeepEqual(t, a.Middlewares, middlewaresExpect) ExpectEqual(t, len(b.Middlewares), 0) - ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault) - ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault) + ExpectEqual(t, a.Container.IdleTimeout, "") + ExpectEqual(t, b.Container.IdleTimeout, "") ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault) ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault) @@ -107,6 +107,7 @@ func TestApplyLabelWildcard(t *testing.T) { func TestApplyLabelWithAlias(t *testing.T) { entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, + State: "running", Labels: map[string]string{ D.LabelAliases: "a,b,c", "proxy.a.no_tls_verify": "true", @@ -114,7 +115,7 @@ func TestApplyLabelWithAlias(t *testing.T) { "proxy.b.port": "1234", "proxy.c.scheme": "https", }, - }, "")) + }, client.DefaultDockerHost)) a, ok := entries.Load("a") ExpectTrue(t, ok) b, ok := entries.Load("b") @@ -134,6 +135,7 @@ func TestApplyLabelWithAlias(t *testing.T) { func TestApplyLabelWithRef(t *testing.T) { entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, + State: "running", Labels: map[string]string{ D.LabelAliases: "a,b,c", "proxy.#1.host": "localhost", @@ -142,7 +144,7 @@ func TestApplyLabelWithRef(t *testing.T) { "proxy.#3.port": "1111", "proxy.#3.scheme": "https", }, - }, ""))) + }, client.DefaultDockerHost))) a, ok := entries.Load("a") ExpectTrue(t, ok) b, ok := entries.Load("b") @@ -161,6 +163,7 @@ func TestApplyLabelWithRef(t *testing.T) { func TestApplyLabelWithRefIndexError(t *testing.T) { c := D.FromDocker(&types.Container{ Names: dummyNames, + State: "running", Labels: map[string]string{ D.LabelAliases: "a,b", "proxy.#1.host": "localhost", @@ -173,6 +176,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { _, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, + State: "running", Labels: map[string]string{ D.LabelAliases: "a,b", "proxy.#0.host": "localhost", @@ -183,7 +187,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { } func TestPublicIPLocalhost(t *testing.T) { - c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost) + c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost) raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1") @@ -191,7 +195,7 @@ func TestPublicIPLocalhost(t *testing.T) { } func TestPublicIPRemote(t *testing.T) { - c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375") + c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") @@ -218,6 +222,7 @@ func TestPrivateIPLocalhost(t *testing.T) { func TestPrivateIPRemote(t *testing.T) { c := D.FromDocker(&types.Container{ Names: dummyNames, + State: "running", NetworkSettings: &types.SummaryNetworkSettings{ Networks: map[string]*network.EndpointSettings{ "network": { @@ -239,6 +244,7 @@ func TestStreamDefaultValues(t *testing.T) { privIP := "172.17.0.123" cont := &types.Container{ Names: []string{"a"}, + State: "running", NetworkSettings: &types.SummaryNetworkSettings{ Networks: map[string]*network.EndpointSettings{ "network": { @@ -256,9 +262,8 @@ func TestStreamDefaultValues(t *testing.T) { raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - entry := Must(P.ValidateEntry(raw)) - - a := ExpectType[*P.StreamEntry](t, entry) + en := Must(entry.ValidateEntry(raw)) + a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) ExpectEqual(t, a.Host, T.Host(privIP)) @@ -270,9 +275,8 @@ func TestStreamDefaultValues(t *testing.T) { c := D.FromDocker(cont, "tcp://1.2.3.4:2375") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - entry := Must(P.ValidateEntry(raw)) - - a := ExpectType[*P.StreamEntry](t, entry) + en := Must(entry.ValidateEntry(raw)) + a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) ExpectEqual(t, a.Host, "1.2.3.4") diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go new file mode 100644 index 0000000..f10ba43 --- /dev/null +++ b/internal/route/provider/event_handler.go @@ -0,0 +1,109 @@ +package provider + +import ( + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/watcher" +) + +type EventHandler struct { + provider *Provider + + added []string + removed []string + paused []string + updated []string + errs E.Builder +} + +func (provider *Provider) newEventHandler() *EventHandler { + return &EventHandler{ + provider: provider, + errs: E.NewBuilder("event errors"), + } +} + +func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { + oldRoutes := handler.provider.routes + newRoutes, err := handler.provider.LoadRoutesImpl() + if err != nil { + handler.errs.Add(err.Subject("load routes")) + return + } + + oldRoutes.RangeAll(func(k string, v *route.Route) { + if !newRoutes.Has(k) { + handler.Remove(v) + } + }) + newRoutes.RangeAll(func(k string, newr *route.Route) { + if oldRoutes.Has(k) { + for _, ev := range events { + if handler.match(ev, newr) { + old, ok := oldRoutes.Load(k) + if !ok { // should not happen + panic("race condition") + } + handler.Update(parent, old, newr) + return + } + } + } else { + handler.Add(parent, newr) + } + }) +} + +func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { + switch handler.provider.t { + case ProviderTypeDocker: + return route.Entry.Container.ContainerID == event.ActorID || + route.Entry.Container.ContainerName == event.ActorName + case ProviderTypeFile: + return true + } + // should never happen + return false +} + +func (handler *EventHandler) Add(parent task.Task, route *route.Route) { + err := handler.provider.startRoute(parent, route) + if err != nil { + handler.errs.Add(err) + } else { + handler.added = append(handler.added, route.Entry.Alias) + } +} + +func (handler *EventHandler) Remove(route *route.Route) { + route.Finish("route removal") + handler.removed = append(handler.removed, route.Entry.Alias) +} + +func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) { + oldRoute.Finish("route update") + err := handler.provider.startRoute(parent, newRoute) + if err != nil { + handler.errs.Add(err) + } else { + handler.updated = append(handler.updated, newRoute.Entry.Alias) + } +} + +func (handler *EventHandler) Log() { + results := E.NewBuilder("event occured") + for _, alias := range handler.added { + results.Addf("added %s", alias) + } + for _, alias := range handler.removed { + results.Addf("removed %s", alias) + } + for _, alias := range handler.updated { + results.Addf("updated %s", alias) + } + results.Add(handler.errs.Build()) + if result := results.Build(); result != nil { + handler.provider.l.Info(result) + } +} diff --git a/internal/proxy/provider/file.go b/internal/route/provider/file.go similarity index 70% rename from internal/proxy/provider/file.go rename to internal/route/provider/file.go index 5a67ebc..5fbc96a 100644 --- a/internal/proxy/provider/file.go +++ b/internal/route/provider/file.go @@ -7,8 +7,8 @@ import ( "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/proxy/entry" R "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/types" U "github.com/yusing/go-proxy/internal/utils" W "github.com/yusing/go-proxy/internal/watcher" ) @@ -42,38 +42,13 @@ func (p FileProvider) String() string { return p.fileName } -func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { - b := E.NewBuilder("event %s error", event) - defer b.To(&res.err) - - newRoutes, err := p.LoadRoutesImpl() - if err != nil { - b.Add(err) - return - } - - res.nRemoved = newRoutes.Size() - routes.RangeAllParallel(func(_ string, v *R.Route) { - b.Add(v.Stop()) - }) - routes.Clear() - - newRoutes.RangeAllParallel(func(_ string, v *R.Route) { - b.Add(v.Start()) - }) - res.nAdded = newRoutes.Size() - - routes.MergeFrom(newRoutes) - return -} - func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { routes = R.NewRoutes() b := E.NewBuilder("file %q validation failure", p.fileName) defer b.To(&res) - entries := types.NewProxyEntries() + entries := entry.NewProxyEntries() data, err := E.Check(os.ReadFile(p.path)) if err != nil { diff --git a/internal/proxy/provider/provider.go b/internal/route/provider/provider.go similarity index 61% rename from internal/proxy/provider/provider.go rename to internal/route/provider/provider.go index 1318a76..45a2667 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/route/provider/provider.go @@ -1,14 +1,16 @@ package provider import ( - "context" + "fmt" "path" + "time" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/task" W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/watcher/events" ) type ( @@ -19,18 +21,14 @@ type ( t ProviderType routes R.Routes - watcher W.Watcher - watcherTask common.Task - watcherCancel context.CancelFunc + watcher W.Watcher l *logrus.Entry } ProviderImpl interface { + fmt.Stringer NewWatcher() W.Watcher - // even returns error, routes must be non-nil LoadRoutesImpl() (R.Routes, E.NestedError) - OnEvent(event W.Event, routes R.Routes) EventResult - String() string } ProviderType string ProviderStats struct { @@ -38,17 +36,13 @@ type ( NumStreams int `json:"num_streams"` Type ProviderType `json:"type"` } - EventResult struct { - nAdded int - nRemoved int - nReloaded int - err E.NestedError - } ) const ( ProviderTypeDocker ProviderType = "docker" ProviderTypeFile ProviderType = "file" + + providerEventFlushInterval = 500 * time.Millisecond ) func newProvider(name string, t ProviderType) *Provider { @@ -106,32 +100,48 @@ func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Provider) StartAllRoutes() (res E.NestedError) { +func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError { + subtask := parent.Subtask("route %s", r.Entry.Alias) + err := r.Start(subtask) + if err != nil { + p.routes.Delete(r.Entry.Alias) + subtask.Finish(err.String()) // just to ensure + return err + } else { + subtask.OnComplete("del from provider", func() { + p.routes.Delete(r.Entry.Alias) + }) + } + return nil +} + +// Start implements task.TaskStarter. +func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) { errors := E.NewBuilder("errors starting routes") defer errors.To(&res) - // start watcher no matter load success or not - go p.watchEvents() + // routes and event queue will stop on parent cancel + providerTask := configSubtask p.routes.RangeAllParallel(func(alias string, r *R.Route) { - errors.Add(r.Start().Subject(r)) + errors.Add(p.startRoute(providerTask, r)) }) - return -} -func (p *Provider) StopAllRoutes() (res E.NestedError) { - if p.watcherCancel != nil { - p.watcherCancel() - p.watcherCancel = nil - } - - errors := E.NewBuilder("errors stopping routes") - defer errors.To(&res) - - p.routes.RangeAllParallel(func(alias string, r *R.Route) { - errors.Add(r.Stop().Subject(r)) - }) - p.routes.Clear() + eventQueue := events.NewEventQueue( + providerTask, + providerEventFlushInterval, + func(flushTask task.Task, events []events.Event) { + handler := p.newEventHandler() + // routes' lifetime should follow the provider's lifetime + handler.Handle(providerTask, events) + handler.Log() + flushTask.Finish("events flushed") + }, + func(err E.NestedError) { + p.l.Error(err) + }, + ) + eventQueue.Start(p.watcher.Events(providerTask.Context())) return } @@ -147,7 +157,6 @@ func (p *Provider) LoadRoutes() E.NestedError { var err E.NestedError p.routes, err = p.LoadRoutesImpl() if p.routes.Size() > 0 { - p.l.Infof("loaded %d routes", p.routes.Size()) return err } if err == nil { @@ -156,13 +165,14 @@ func (p *Provider) LoadRoutes() E.NestedError { return E.FailWith("loading routes", err) } +func (p *Provider) NumRoutes() int { + return p.routes.Size() +} + func (p *Provider) Statistics() ProviderStats { numRPs := 0 numStreams := 0 p.routes.RangeAll(func(_ string, r *R.Route) { - if !r.Started() { - return - } switch r.Type { case R.RouteTypeReverseProxy: numRPs++ @@ -176,34 +186,3 @@ func (p *Provider) Statistics() ProviderStats { Type: p.t, } } - -func (p *Provider) watchEvents() { - p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name) - defer p.watcherTask.Finished() - - events, errs := p.watcher.Events(p.watcherTask.Context()) - l := p.l.WithField("module", "watcher") - - for { - select { - case <-p.watcherTask.Context().Done(): - return - case event := <-events: - task := p.watcherTask.Subtask("%s event %s", event.Type, event) - l.Infof("%s event %q", event.Type, event) - res := p.OnEvent(event, p.routes) - task.Finished() - if res.nAdded+res.nRemoved+res.nReloaded > 0 { - l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded) - } - if res.err != nil { - l.Error(res.err) - } - case err := <-errs: - if err == nil || err.Is(context.Canceled) { - continue - } - l.Errorf("watcher error: %s", err) - } - } -} diff --git a/internal/route/route.go b/internal/route/route.go index 58d3782..6bcf114 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -4,8 +4,8 @@ import ( "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" url "github.com/yusing/go-proxy/internal/net/types" - P "github.com/yusing/go-proxy/internal/proxy" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/proxy/entry" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -16,16 +16,16 @@ type ( _ U.NoCopy impl Type RouteType - Entry *types.RawEntry + Entry *entry.RawEntry } Routes = F.Map[string, *Route] impl interface { - Start() E.NestedError - Stop() E.NestedError - Started() bool + entry.Entry + task.TaskStarter + task.TaskFinisher String() string - URL() url.URL + TargetURL() url.URL } ) @@ -44,8 +44,8 @@ func (rt *Route) Container() *docker.Container { return rt.Entry.Container } -func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { - entry, err := P.ValidateEntry(en) +func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) { + en, err := entry.ValidateEntry(raw) if err != nil { return nil, err } @@ -53,11 +53,11 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { var t RouteType var rt impl - switch e := entry.(type) { - case *P.StreamEntry: + switch e := en.(type) { + case *entry.StreamEntry: t = RouteTypeStream rt, err = NewStreamRoute(e) - case *P.ReverseProxyEntry: + case *entry.ReverseProxyEntry: t = RouteTypeReverseProxy rt, err = NewHTTPRoute(e) default: @@ -69,19 +69,21 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { return &Route{ impl: rt, Type: t, - Entry: en, + Entry: raw, }, nil } -func FromEntries(entries types.RawEntries) (Routes, E.NestedError) { +func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) { b := E.NewBuilder("errors in routes") routes := NewRoutes() - entries.RangeAll(func(alias string, entry *types.RawEntry) { - entry.Alias = alias - r, err := NewRoute(entry) + entries.RangeAllParallel(func(alias string, en *entry.RawEntry) { + en.Alias = alias + r, err := NewRoute(en) if err != nil { b.Add(err.Subject(alias)) + } else if entry.ShouldNotServe(r) { + return } else { routes.Store(alias, r) } diff --git a/internal/route/stream.go b/internal/route/stream.go index 1d9c38b..03b268c 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -4,169 +4,141 @@ import ( "context" "errors" "fmt" - "net" + stdNet "net" "sync" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" - url "github.com/yusing/go-proxy/internal/net/types" - P "github.com/yusing/go-proxy/internal/proxy" - PT "github.com/yusing/go-proxy/internal/proxy/fields" + net "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/proxy/entry" + "github.com/yusing/go-proxy/internal/task" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/health" ) type StreamRoute struct { - *P.StreamEntry - StreamImpl `json:"-"` + *entry.StreamEntry + net.Stream `json:"-"` HealthMon health.HealthMonitor `json:"health"` - url url.URL - - task common.Task - cancel context.CancelFunc - done chan struct{} + task task.Task l logrus.FieldLogger - - mu sync.Mutex } -type StreamImpl interface { - Setup() error - Accept() (any, error) - Handle(conn any) error - CloseListeners() - String() string -} - -var streamRoutes = F.NewMapOf[string, *StreamRoute]() +var ( + streamRoutes = F.NewMapOf[string, *StreamRoute]() + streamRoutesMu sync.Mutex +) func GetStreamProxies() F.Map[string, *StreamRoute] { return streamRoutes } -func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { +func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) { // TODO: support non-coherent scheme if !entry.Scheme.IsCoherent() { return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) } - url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort)) - if err != nil { - // !! should not happen - panic(err) - } - base := &StreamRoute{ + return &StreamRoute{ StreamEntry: entry, - url: url, - } - if entry.Scheme.ListeningScheme.IsTCP() { - base.StreamImpl = NewTCPRoute(base) - } else { - base.StreamImpl = NewUDPRoute(base) - } - base.l = logrus.WithField("route", base.StreamImpl) - return base, nil + task: task.DummyTask(), + }, nil +} + +func (r *StreamRoute) Finish(reason string) { + r.task.Finish(reason) } func (r *StreamRoute) String() string { return fmt.Sprintf("stream %s", r.Alias) } -func (r *StreamRoute) URL() url.URL { - return r.url -} - -func (r *StreamRoute) Start() E.NestedError { - r.mu.Lock() - defer r.mu.Unlock() - - if r.Port.ProxyPort == PT.NoPort || r.task != nil { +// Start implements task.TaskStarter. +func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { + if entry.ShouldNotServe(r) { + providerSubtask.Finish("should not serve") return nil } - r.task, r.cancel = common.NewTaskWithCancel(r.String()) + + streamRoutesMu.Lock() + defer streamRoutesMu.Unlock() + + if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { + logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) + r.HealthCheck.Disabled = true + } + + if r.Scheme.ListeningScheme.IsTCP() { + r.Stream = NewTCPRoute(r) + } else { + r.Stream = NewUDPRoute(r) + } + r.l = logrus.WithField("route", r.Stream.String()) + + switch { + case entry.UseIdleWatcher(r): + wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) + waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) + if err != nil { + return err + } + r.Stream = waker + r.HealthMon = waker + case entry.UseHealthCheck(r): + r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck) + } + r.task = providerSubtask + r.task.OnComplete("stop stream", r.CloseListeners) + if err := r.Setup(); err != nil { return E.FailWith("setup", err) } - r.done = make(chan struct{}) r.l.Infof("listening on port %d", r.Port.ListeningPort) + go r.acceptConnections() - if !r.Healthcheck.Disabled { - r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck) - r.HealthMon.Start() + + if r.HealthMon != nil { + r.HealthMon.Start(r.task.Subtask("health monitor")) } streamRoutes.Store(string(r.Alias), r) return nil } -func (r *StreamRoute) Stop() E.NestedError { - r.mu.Lock() - defer r.mu.Unlock() - - if r.task == nil { - return nil - } - - streamRoutes.Delete(string(r.Alias)) - - if r.HealthMon != nil { - r.HealthMon.Stop() - r.HealthMon = nil - } - - r.cancel() - r.CloseListeners() - - <-r.done - return nil -} - -func (r *StreamRoute) Started() bool { - return r.task != nil -} - func (r *StreamRoute) acceptConnections() { - var connWg sync.WaitGroup - - task := r.task.Subtask("%s accept connections", r.String()) - - defer func() { - connWg.Wait() - task.Finished() - r.task.Finished() - r.task, r.cancel = nil, nil - close(r.done) - r.done = nil - }() - for { select { - case <-task.Context().Done(): + case <-r.task.Context().Done(): return default: conn, err := r.Accept() if err != nil { select { - case <-task.Context().Done(): + case <-r.task.Context().Done(): return default: - var nErr *net.OpError + var nErr *stdNet.OpError ok := errors.As(err, &nErr) if !(ok && nErr.Timeout()) { - r.l.Error(err) + r.l.Error("accept connection error: ", err) + r.task.Finish(err.Error()) + return } continue } } - connWg.Add(1) + connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String()) go func() { err := r.Handle(conn) if err != nil && !errors.Is(err, context.Canceled) { r.l.Error(err) + connTask.Finish(err.Error()) + } else { + connTask.Finish("connection closed") } - connWg.Done() + conn.Close() }() } } diff --git a/internal/route/tcp.go b/internal/route/tcp.go index e076a76..20d378b 100755 --- a/internal/route/tcp.go +++ b/internal/route/tcp.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/yusing/go-proxy/internal/net/types" T "github.com/yusing/go-proxy/internal/proxy/fields" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -21,7 +22,7 @@ type ( } ) -func NewTCPRoute(base *StreamRoute) StreamImpl { +func NewTCPRoute(base *StreamRoute) *TCPRoute { return &TCPRoute{StreamRoute: base} } @@ -36,19 +37,16 @@ func (route *TCPRoute) Setup() error { return nil } -func (route *TCPRoute) Accept() (any, error) { +func (route *TCPRoute) Accept() (types.StreamConn, error) { route.listener.SetDeadline(time.Now().Add(time.Second)) return route.listener.Accept() } -func (route *TCPRoute) Handle(c any) error { +func (route *TCPRoute) Handle(c types.StreamConn) error { clientConn := c.(net.Conn) defer clientConn.Close() - go func() { - <-route.task.Context().Done() - clientConn.Close() - }() + route.task.OnComplete("close conn", func() { clientConn.Close() }) ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) @@ -70,5 +68,4 @@ func (route *TCPRoute) CloseListeners() { return } route.listener.Close() - route.listener = nil } diff --git a/internal/route/udp.go b/internal/route/udp.go index faadff6..2f19f96 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -1,11 +1,13 @@ package route import ( + "errors" "fmt" "io" "net" "time" + "github.com/yusing/go-proxy/internal/net/types" T "github.com/yusing/go-proxy/internal/proxy/fields" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -33,7 +35,7 @@ var NewUDPConnMap = F.NewMap[UDPConnMap] const udpBufferSize = 8192 -func NewUDPRoute(base *StreamRoute) StreamImpl { +func NewUDPRoute(base *StreamRoute) *UDPRoute { return &UDPRoute{ StreamRoute: base, connMap: NewUDPConnMap(), @@ -64,7 +66,7 @@ func (route *UDPRoute) Setup() error { return nil } -func (route *UDPRoute) Accept() (any, error) { +func (route *UDPRoute) Accept() (types.StreamConn, error) { in := route.listeningConn buffer := make([]byte, udpBufferSize) @@ -104,7 +106,7 @@ func (route *UDPRoute) Accept() (any, error) { return conn, err } -func (route *UDPRoute) Handle(c any) error { +func (route *UDPRoute) Handle(c types.StreamConn) error { conn := c.(*UDPConn) err := conn.Start() route.connMap.Delete(conn.key) @@ -114,19 +116,25 @@ func (route *UDPRoute) Handle(c any) error { func (route *UDPRoute) CloseListeners() { if route.listeningConn != nil { route.listeningConn.Close() - route.listeningConn = nil } route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) { - if err := conn.src.Close(); err != nil { - route.l.Errorf("error closing src conn: %s", err) - } - if err := conn.dst.Close(); err != nil { - route.l.Error("error closing dst conn: %s", err) + if err := conn.Close(); err != nil { + route.l.Errorf("error closing conn: %s", err) } }) route.connMap.Clear() } +// Close implements types.StreamConn +func (conn *UDPConn) Close() error { + return errors.Join(conn.src.Close(), conn.dst.Close()) +} + +// RemoteAddr implements types.StreamConn +func (conn *UDPConn) RemoteAddr() net.Addr { + return conn.src.RemoteAddr() +} + type sourceRWCloser struct { server *net.UDPConn *net.UDPConn diff --git a/internal/server/server.go b/internal/server/server.go index 7bf2120..b6784e9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/tls" "errors" "log" @@ -9,8 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/autocert" - "github.com/yusing/go-proxy/internal/common" - "golang.org/x/net/context" + "github.com/yusing/go-proxy/internal/task" ) type Server struct { @@ -21,7 +21,8 @@ type Server struct { httpStarted bool httpsStarted bool startTime time.Time - task common.Task + + task task.Task } type Options struct { @@ -84,7 +85,7 @@ func NewServer(opt Options) (s *Server) { CertProvider: opt.CertProvider, http: httpSer, https: httpsSer, - task: common.GlobalTask(opt.Name + " server"), + task: task.GlobalTask(opt.Name + " server"), } } @@ -115,11 +116,7 @@ func (s *Server) Start() { }() } - go func() { - <-s.task.Context().Done() - s.stop() - s.task.Finished() - }() + s.task.OnComplete("stop server", s.stop) } func (s *Server) stop() { @@ -127,16 +124,13 @@ func (s *Server) stop() { return } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - if s.http != nil && s.httpStarted { - s.handleErr("http", s.http.Shutdown(ctx)) + s.handleErr("http", s.http.Shutdown(s.task.Context())) s.httpStarted = false } if s.https != nil && s.httpsStarted { - s.handleErr("https", s.https.Shutdown(ctx)) + s.handleErr("https", s.https.Shutdown(s.task.Context())) s.httpsStarted = false } } @@ -147,7 +141,7 @@ func (s *Server) Uptime() time.Duration { func (s *Server) handleErr(scheme string, err error) { switch { - case err == nil, errors.Is(err, http.ErrServerClosed): + case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): return default: logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err) diff --git a/internal/task/dummy_task.go b/internal/task/dummy_task.go new file mode 100644 index 0000000..51eca11 --- /dev/null +++ b/internal/task/dummy_task.go @@ -0,0 +1,43 @@ +package task + +import "context" + +type dummyTask struct{} + +func DummyTask() (_ Task) { + return +} + +// Context implements Task. +func (d dummyTask) Context() context.Context { + panic("call of dummyTask.Context") +} + +// Finish implements Task. +func (d dummyTask) Finish() {} + +// Name implements Task. +func (d dummyTask) Name() string { + return "Dummy Task" +} + +// OnComplete implements Task. +func (d dummyTask) OnComplete(about string, fn func()) { + panic("call of dummyTask.OnComplete") +} + +// Parent implements Task. +func (d dummyTask) Parent() Task { + panic("call of dummyTask.Parent") +} + +// Subtask implements Task. +func (d dummyTask) Subtask(usageFmt string, args ...any) Task { + panic("call of dummyTask.Subtask") +} + +// Wait implements Task. +func (d dummyTask) Wait() {} + +// WaitSubTasks implements Task. +func (d dummyTask) WaitSubTasks() {} diff --git a/internal/task/task.go b/internal/task/task.go new file mode 100644 index 0000000..742f898 --- /dev/null +++ b/internal/task/task.go @@ -0,0 +1,310 @@ +package task + +import ( + "context" + "errors" + "fmt" + "runtime" + "strings" + "sync" + "time" + + "github.com/puzpuzpuz/xsync/v3" + "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" +) + +var globalTask = createGlobalTask() + +func createGlobalTask() (t *task) { + t = new(task) + t.name = "root" + t.ctx, t.cancel = context.WithCancelCause(context.Background()) + t.subtasks = xsync.NewMapOf[*task, struct{}]() + return +} + +type ( + // Task controls objects' lifetime. + // + // Task must be initialized, use DummyTask if the task is not yet started. + // + // Objects that uses a task should implement the TaskStarter and the TaskFinisher interface. + // + // When passing a Task object to another function, + // it must be a sub-task of the current task, + // in name of "`currentTaskName`Subtask" + // + // Use Task.Finish to stop all subtasks of the task. + Task interface { + TaskFinisher + // Name returns the name of the task. + Name() string + // Context returns the context associated with the task. This context is + // canceled when Finish is called. + Context() context.Context + // FinishCause returns the reason / error that caused the task to be finished. + FinishCause() error + // Parent returns the parent task of the current task. + Parent() Task + // Subtask returns a new subtask with the given name, derived from the parent's context. + // + // If the parent's context is already canceled, the returned subtask will be canceled immediately. + // + // This should not be called after Finish, Wait, or WaitSubTasks is called. + Subtask(usageFmt string, args ...any) Task + // OnComplete calls fn when the task and all subtasks are finished. + // + // It cannot be called after Finish or Wait is called. + OnComplete(about string, fn func()) + // Wait waits for all subtasks, itself and all OnComplete to finish. + // + // It must be called only after Finish is called. + Wait() + // WaitSubTasks waits for all subtasks of the task to finish. + // + // No more subtasks can be added after this call. + // + // It can be called before Finish is called. + WaitSubTasks() + } + TaskStarter interface { + // Start starts the object that implements TaskStarter, + // and returns an error if it fails to start. + // + // The task passed must be a subtask of the caller task. + // + // callerSubtask.Finish must be called when start fails or the object is finished. + Start(callerSubtask Task) E.NestedError + } + TaskFinisher interface { + // Finish marks the task as finished by cancelling its context. + // + // Then call Wait to wait for all subtasks and OnComplete of the task to finish. + // + // Note that it will also cancel all subtasks. + Finish(reason string) + } + task struct { + ctx context.Context + cancel context.CancelCauseFunc + + parent *task + subtasks *xsync.MapOf[*task, struct{}] + + name, line string + + subTasksWg, onCompleteWg sync.WaitGroup + } +) + +var ( + ErrProgramExiting = errors.New("program exiting") + ErrTaskCancelled = errors.New("task cancelled") +) + +// GlobalTask returns a new Task with the given name, derived from the global context. +func GlobalTask(format string, args ...any) Task { + return globalTask.Subtask(format, args...) +} + +// DebugTaskMap returns a map[string]any representation of the global task tree. +// +// The returned map is suitable for encoding to JSON, and can be used +// to debug the task tree. +// +// The returned map is not guaranteed to be stable, and may change +// between runs of the program. It is intended for debugging purposes +// only. +func DebugTaskMap() map[string]any { + return globalTask.serialize() +} + +// CancelGlobalContext cancels the global task context, which will cause all tasks +// created to be canceled. This should be called before exiting the program +// to ensure that all tasks are properly cleaned up. +func CancelGlobalContext() { + globalTask.cancel(ErrProgramExiting) +} + +// GlobalContextWait waits for all tasks to finish, up to the given timeout. +// +// If the timeout is exceeded, it prints a list of all tasks that were +// still running when the timeout was reached, and their current tree +// of subtasks. +func GlobalContextWait(timeout time.Duration) { + done := make(chan struct{}) + after := time.After(timeout) + go func() { + globalTask.Wait() + close(done) + }() + for { + select { + case <-done: + return + case <-after: + logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) + return + } + } +} + +func (t *task) Name() string { + return t.name +} + +func (t *task) Context() context.Context { + return t.ctx +} + +func (t *task) FinishCause() error { + return context.Cause(t.ctx) +} + +func (t *task) Parent() Task { + return t.parent +} + +func (t *task) OnComplete(about string, fn func()) { + t.onCompleteWg.Add(1) + var file string + var line int + if common.IsTrace { + _, file, line, _ = runtime.Caller(1) + } + go func() { + defer func() { + if err := recover(); err != nil { + logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err) + } + }() + defer t.onCompleteWg.Done() + t.subTasksWg.Wait() + <-t.ctx.Done() + fn() + logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about) + t.cancel(nil) // ensure resources are released + }() +} + +func (t *task) Finish(reason string) { + t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason)) + t.Wait() +} + +func (t *task) Subtask(format string, args ...any) Task { + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + ctx, cancel := context.WithCancelCause(t.ctx) + return t.newSubTask(ctx, cancel, format) +} + +func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task { + parent := t + subtask := &task{ + ctx: ctx, + cancel: cancel, + name: name, + parent: parent, + subtasks: xsync.NewMapOf[*task, struct{}](), + } + parent.subTasksWg.Add(1) + parent.subtasks.Store(subtask, struct{}{}) + if common.IsTrace { + _, file, line, ok := runtime.Caller(3) + if ok { + subtask.line = fmt.Sprintf("%s:%d", file, line) + } + logrus.Tracef("line %s\ntask %q started", subtask.line, name) + go func() { + subtask.Wait() + logrus.Tracef("task %q finished", subtask.Name()) + }() + } + go func() { + subtask.Wait() + parent.subtasks.Delete(subtask) + parent.subTasksWg.Done() + }() + return subtask +} + +func (t *task) Wait() { + t.subTasksWg.Wait() + if t != globalTask { + <-t.ctx.Done() + } + t.onCompleteWg.Wait() +} + +func (t *task) WaitSubTasks() { + t.subTasksWg.Wait() +} + +// tree returns a string representation of the task tree, with the given +// prefix prepended to each line. The prefix is used to indent the tree, +// and should be a string of spaces or a similar separator. +// +// The resulting string is suitable for printing to the console, and can be +// used to debug the task tree. +// +// The tree is traversed in a depth-first manner, with each task's name and +// line number (if available) printed on a separate line. The line number is +// only printed if the task was created with a non-empty line argument. +// +// The returned string is not guaranteed to be stable, and may change between +// runs of the program. It is intended for debugging purposes only. +func (t *task) tree(prefix ...string) string { + var sb strings.Builder + var pre string + if len(prefix) > 0 { + pre = prefix[0] + sb.WriteString(pre + "- ") + } + if t.line != "" { + sb.WriteString("line " + t.line + "\n") + } + if len(pre) > 0 { + sb.WriteString(pre + "- ") + } + sb.WriteString(t.Name() + "\n") + t.subtasks.Range(func(subtask *task, _ struct{}) bool { + sb.WriteString(subtask.tree(pre + " ")) + return true + }) + return sb.String() +} + +// serialize returns a map[string]any representation of the task tree. +// +// The map contains the following keys: +// - name: the name of the task +// - line: the line number of the task, if available +// - subtasks: a slice of maps, each representing a subtask +// +// The subtask maps contain the same keys, recursively. +// +// The returned map is suitable for encoding to JSON, and can be used +// to debug the task tree. +// +// The returned map is not guaranteed to be stable, and may change +// between runs of the program. It is intended for debugging purposes +// only. +func (t *task) serialize() map[string]any { + m := make(map[string]any) + m["name"] = t.name + if t.line != "" { + m["line"] = t.line + } + if t.subtasks.Size() > 0 { + m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size()) + t.subtasks.Range(func(subtask *task, _ struct{}) bool { + m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize()) + return true + }) + } + return m +} diff --git a/internal/task/task_test.go b/internal/task/task_test.go new file mode 100644 index 0000000..1fc08e0 --- /dev/null +++ b/internal/task/task_test.go @@ -0,0 +1,147 @@ +package task_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + . "github.com/yusing/go-proxy/internal/task" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestTaskCreation(t *testing.T) { + defer CancelGlobalContext() + + rootTask := GlobalTask("root-task") + subTask := rootTask.Subtask("subtask") + + ExpectEqual(t, "root-task", rootTask.Name()) + ExpectEqual(t, "subtask", subTask.Name()) +} + +func TestTaskCancellation(t *testing.T) { + defer CancelGlobalContext() + + subTaskDone := make(chan struct{}) + + rootTask := GlobalTask("root-task") + subTask := rootTask.Subtask("subtask") + + go func() { + subTask.Wait() + close(subTaskDone) + }() + + go rootTask.Finish("done") + + select { + case <-subTaskDone: + err := subTask.Context().Err() + ExpectError(t, context.Canceled, err) + cause := context.Cause(subTask.Context()) + ExpectError(t, ErrTaskCancelled, cause) + case <-time.After(1 * time.Second): + t.Fatal("subTask context was not canceled as expected") + } +} + +func TestGlobalContextCancellation(t *testing.T) { + taskDone := make(chan struct{}) + + rootTask := GlobalTask("root-task") + + go func() { + rootTask.Wait() + close(taskDone) + }() + + CancelGlobalContext() + + select { + case <-taskDone: + err := rootTask.Context().Err() + ExpectError(t, context.Canceled, err) + cause := context.Cause(rootTask.Context()) + ExpectError(t, ErrProgramExiting, cause) + case <-time.After(1 * time.Second): + t.Fatal("subTask context was not canceled as expected") + } +} + +func TestOnComplete(t *testing.T) { + defer CancelGlobalContext() + + task := GlobalTask("test") + var value atomic.Int32 + task.OnComplete("set value", func() { + value.Store(1234) + }) + task.Finish("done") + ExpectEqual(t, value.Load(), 1234) +} + +func TestGlobalContextWait(t *testing.T) { + defer CancelGlobalContext() + + rootTask := GlobalTask("root-task") + + finished1, finished2 := false, false + + subTask1 := rootTask.Subtask("subtask1") + subTask2 := rootTask.Subtask("subtask2") + subTask1.OnComplete("set finished", func() { + finished1 = true + }) + subTask2.OnComplete("set finished", func() { + finished2 = true + }) + + go func() { + time.Sleep(500 * time.Millisecond) + subTask1.Finish("done") + }() + + go func() { + time.Sleep(500 * time.Millisecond) + subTask2.Finish("done") + }() + + go func() { + subTask1.Wait() + subTask2.Wait() + rootTask.Finish("done") + }() + + GlobalContextWait(1 * time.Second) + ExpectTrue(t, finished1) + ExpectTrue(t, finished2) + ExpectError(t, context.Canceled, rootTask.Context().Err()) + ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context())) + ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context())) +} + +func TestTimeoutOnGlobalContextWait(t *testing.T) { + defer CancelGlobalContext() + + rootTask := GlobalTask("root-task") + subTask := rootTask.Subtask("subtask") + + done := make(chan struct{}) + go func() { + GlobalContextWait(500 * time.Millisecond) + close(done) + }() + + select { + case <-done: + t.Fatal("GlobalContextWait should have timed out") + case <-time.After(200 * time.Millisecond): + } + + // Ensure clean exit + subTask.Finish("exit") +} + +func TestGlobalContextCancel(t *testing.T) { +} diff --git a/internal/types/config.go b/internal/types/config.go deleted file mode 100644 index bbbc21f..0000000 --- a/internal/types/config.go +++ /dev/null @@ -1,18 +0,0 @@ -package types - -type Config struct { - Providers ProxyProviders `json:"providers" yaml:",flow"` - AutoCert AutoCertConfig `json:"autocert" yaml:",flow"` - ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"` - MatchDomains []string `json:"match_domains" yaml:"match_domains"` - TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"` - RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"` -} - -func DefaultConfig() *Config { - return &Config{ - Providers: ProxyProviders{}, - TimeoutShutdown: 3, - RedirectToHTTPS: false, - } -} diff --git a/internal/types/proxy_providers.go b/internal/types/proxy_providers.go deleted file mode 100644 index 7ba4efa..0000000 --- a/internal/types/proxy_providers.go +++ /dev/null @@ -1,6 +0,0 @@ -package types - -type ProxyProviders struct { - Files []string `json:"include" yaml:"include"` // docker, file - Docker map[string]string `json:"docker" yaml:"docker"` -} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index ebfcf0f..d74b518 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -23,7 +23,7 @@ func IgnoreError[Result any](r Result, _ error) Result { func ExpectNoError(t *testing.T, err error) { t.Helper() if err != nil && !reflect.ValueOf(err).IsNil() { - t.Errorf("expected err=nil, got %s", err.Error()) + t.Errorf("expected err=nil, got %s", err) t.FailNow() } } @@ -31,7 +31,7 @@ func ExpectNoError(t *testing.T, err error) { func ExpectError(t *testing.T, expected error, err error) { t.Helper() if !errors.Is(err, expected) { - t.Errorf("expected err %s, got %s", expected.Error(), err.Error()) + t.Errorf("expected err %s, got %s", expected, err) t.FailNow() } } @@ -39,7 +39,7 @@ func ExpectError(t *testing.T, expected error, err error) { func ExpectError2(t *testing.T, input any, expected error, err error) { t.Helper() if !errors.Is(err, expected) { - t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error()) + t.Errorf("%v: expected err %s, got %s", input, expected, err) t.FailNow() } } diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 57369c5..6446860 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -15,8 +15,9 @@ import ( type ( DockerWatcher struct { - host string - client D.Client + host string + client D.Client + clientOwned bool logrus.FieldLogger } DockerListOptions = docker_events.ListOptions @@ -44,10 +45,11 @@ func DockerrFilterContainer(nameOrID string) filters.KeyValuePair { func NewDockerWatcher(host string) DockerWatcher { return DockerWatcher{ + host: host, + clientOwned: true, FieldLogger: (logrus. WithField("module", "docker_watcher"). WithField("host", host)), - host: host, } } @@ -72,7 +74,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList defer close(errCh) defer func() { - if w.client.Connected() { + if w.clientOwned && w.client.Connected() { w.client.Close() } }() diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go new file mode 100644 index 0000000..d8770c3 --- /dev/null +++ b/internal/watcher/events/event_queue.go @@ -0,0 +1,91 @@ +package events + +import ( + "time" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/task" +) + +type ( + EventQueue struct { + task task.Task + queue []Event + ticker *time.Ticker + onFlush OnFlushFunc + onError OnErrorFunc + } + OnFlushFunc = func(flushTask task.Task, events []Event) + OnErrorFunc = func(err E.NestedError) +) + +const eventQueueCapacity = 10 + +// NewEventQueue returns a new EventQueue with the given +// queueTask, flushInterval, onFlush and onError. +// +// The returned EventQueue will start a goroutine to flush events in the queue +// when the flushInterval is reached. +// +// The onFlush function is called when the flushInterval is reached and the queue is not empty, +// +// The onError function is called when an error received from the errCh, +// or panic occurs in the onFlush function. Panic will cause a E.ErrPanicRecv error. +// +// flushTask.Finish must be called after the flush is done, +// but the onFlush function can return earlier (e.g. run in another goroutine). +// +// If task is cancelled before the flushInterval is reached, the events in queue will be discarded. +func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { + return &EventQueue{ + task: parent.Subtask("event queue"), + queue: make([]Event, 0, eventQueueCapacity), + ticker: time.NewTicker(flushInterval), + onFlush: onFlush, + onError: onError, + } +} + +func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) { + go func() { + defer e.ticker.Stop() + for { + select { + case <-e.task.Context().Done(): + e.task.Finish(e.task.FinishCause().Error()) + return + case <-e.ticker.C: + if len(e.queue) > 0 { + flushTask := e.task.Subtask("flush events") + queue := e.queue + e.queue = make([]Event, 0, eventQueueCapacity) + go func() { + defer func() { + if err := recover(); err != nil { + e.onError(E.PanicRecv("panic in onFlush %s", err)) + } + }() + e.onFlush(flushTask, queue) + }() + flushTask.Wait() + } + case event, ok := <-eventCh: + e.queue = append(e.queue, event) + if !ok { + return + } + case err := <-errCh: + if err != nil { + e.onError(err) + } + } + } + }() +} + +// Wait waits for all events to be flushed and the task to finish. +// +// It is safe to call this method multiple times. +func (e *EventQueue) Wait() { + e.task.Wait() +} diff --git a/internal/watcher/events/events.go b/internal/watcher/events/events.go index 876f77d..069c376 100644 --- a/internal/watcher/events/events.go +++ b/internal/watcher/events/events.go @@ -74,7 +74,7 @@ var actionNameMap = func() (m map[Action]string) { }() func (e Event) String() string { - return fmt.Sprintf("%s %s", e.ActorName, e.Action) + return fmt.Sprintf("%s %s", e.Action, e.ActorName) } func (a Action) String() string { diff --git a/internal/watcher/health/http.go b/internal/watcher/health/http.go index 2c490e8..b3b92ee 100644 --- a/internal/watcher/health/http.go +++ b/internal/watcher/health/http.go @@ -5,7 +5,6 @@ import ( "errors" "net/http" - "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/types" ) @@ -15,10 +14,10 @@ type HTTPHealthMonitor struct { pinger *http.Client } -func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor { +func NewHTTPHealthMonitor(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) *HTTPHealthMonitor { mon := new(HTTPHealthMonitor) - mon.monitor = newMonitor(task, url, config, mon.checkHealth) - mon.pinger = &http.Client{Timeout: config.Timeout} + mon.monitor = newMonitor(url, config, mon.CheckHealth) + mon.pinger = &http.Client{Timeout: config.Timeout, Transport: transport} if config.UseGet { mon.method = http.MethodGet } else { @@ -27,19 +26,26 @@ func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckCo return mon } -func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) { +func NewHTTPHealthChecker(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) HealthChecker { + return NewHTTPHealthMonitor(url, config, transport) +} + +func (mon *HTTPHealthMonitor) CheckHealth() (healthy bool, detail string, err error) { + ctx, cancel := mon.ContextWithTimeout("ping request timed out") + defer cancel() + req, reqErr := http.NewRequestWithContext( - mon.task.Context(), + ctx, mon.method, - mon.url.JoinPath(mon.config.Path).String(), + mon.url.Load().JoinPath(mon.config.Path).String(), nil, ) if reqErr != nil { err = reqErr return } - req.Header.Set("Connection", "close") + req.Header.Set("Connection", "close") resp, respErr := mon.pinger.Do(req) if respErr == nil { resp.Body.Close() diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index 71b1db3..1890b6c 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -2,78 +2,93 @@ package health import ( "context" + "encoding/json" "errors" - "sync" + "fmt" "time" - "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) type ( HealthMonitor interface { - Start() - Stop() + task.TaskStarter + task.TaskFinisher + fmt.Stringer + json.Marshaler Status() Status Uptime() time.Duration Name() string - String() string - MarshalJSON() ([]byte, error) + } + HealthChecker interface { + CheckHealth() (healthy bool, detail string, err error) + URL() types.URL + Config() *HealthCheckConfig + UpdateURL(url types.URL) } HealthCheckFunc func() (healthy bool, detail string, err error) monitor struct { service string config *HealthCheckConfig - url types.URL + url U.AtomicValue[types.URL] status U.AtomicValue[Status] checkHealth HealthCheckFunc startTime time.Time - task common.Task - cancel context.CancelFunc - done chan struct{} - - mu sync.Mutex + task task.Task } ) var monMap = F.NewMapOf[string, HealthMonitor]() -func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { - service := task.Name() - task, cancel := task.SubtaskWithCancel("Health monitor for %s", service) +func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { mon := &monitor{ - service: service, config: config, - url: url, checkHealth: healthCheckFunc, startTime: time.Now(), - task: task, - cancel: cancel, - done: make(chan struct{}), + task: task.DummyTask(), } + mon.url.Store(url) mon.status.Store(StatusHealthy) return mon } -func Inspect(name string) (HealthMonitor, bool) { - return monMap.Load(name) +func Inspect(service string) (HealthMonitor, bool) { + return monMap.Load(service) } -func (mon *monitor) Start() { - defer monMap.Store(mon.task.Name(), mon) +func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) { + if mon.task != nil { + return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause)) + } else { + return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause)) + } +} +// Start implements task.TaskStarter. +func (mon *monitor) Start(routeSubtask task.Task) E.NestedError { + mon.service = routeSubtask.Parent().Name() + mon.task = routeSubtask + + if err := mon.checkUpdateHealth(); err != nil { + mon.task.Finish(fmt.Sprintf("healthchecker %s failure: %s", mon.service, err)) + return err + } go func() { - defer close(mon.done) - defer mon.task.Finished() + defer func() { + monMap.Delete(mon.task.Name()) + if mon.status.Load() != StatusError { + mon.status.Store(StatusUnknown) + } + mon.task.Finish(mon.task.FinishCause().Error()) + }() - ok := mon.checkUpdateHealth() - if !ok { - return - } + monMap.Store(mon.service, mon) ticker := time.NewTicker(mon.config.Interval) defer ticker.Stop() @@ -83,48 +98,61 @@ func (mon *monitor) Start() { case <-mon.task.Context().Done(): return case <-ticker.C: - ok = mon.checkUpdateHealth() - if !ok { + err := mon.checkUpdateHealth() + if err != nil { + logger.Errorf("healthchecker %s failure: %s", mon.service, err) return } } } }() + return nil } -func (mon *monitor) Stop() { - monMap.Delete(mon.task.Name()) - - mon.mu.Lock() - defer mon.mu.Unlock() - - if mon.cancel == nil { - return - } - - mon.cancel() - <-mon.done - - mon.cancel = nil - mon.status.Store(StatusUnknown) +// Finish implements task.TaskFinisher. +func (mon *monitor) Finish(reason string) { + mon.task.Finish(reason) } +// UpdateURL implements HealthChecker. +func (mon *monitor) UpdateURL(url types.URL) { + mon.url.Store(url) +} + +// URL implements HealthChecker. +func (mon *monitor) URL() types.URL { + return mon.url.Load() +} + +// Config implements HealthChecker. +func (mon *monitor) Config() *HealthCheckConfig { + return mon.config +} + +// Status implements HealthMonitor. func (mon *monitor) Status() Status { return mon.status.Load() } +// Uptime implements HealthMonitor. func (mon *monitor) Uptime() time.Duration { return time.Since(mon.startTime) } +// Name implements HealthMonitor. func (mon *monitor) Name() string { + if mon.task == nil { + return "" + } return mon.task.Name() } +// String implements fmt.Stringer of HealthMonitor. func (mon *monitor) String() string { return mon.Name() } +// MarshalJSON implements json.Marshaler of HealthMonitor. func (mon *monitor) MarshalJSON() ([]byte, error) { return (&JSONRepresentation{ Name: mon.service, @@ -132,19 +160,19 @@ func (mon *monitor) MarshalJSON() ([]byte, error) { Status: mon.status.Load(), Started: mon.startTime, Uptime: mon.Uptime(), - URL: mon.url, + URL: mon.url.Load(), }).MarshalJSON() } -func (mon *monitor) checkUpdateHealth() (hasError bool) { +func (mon *monitor) checkUpdateHealth() E.NestedError { healthy, detail, err := mon.checkHealth() if err != nil { + defer mon.task.Finish(err.Error()) mon.status.Store(StatusError) if !errors.Is(err, context.Canceled) { - logger.Errorf("%s failed to check health: %s", mon.service, err) + return E.Failure("check health").With(err) } - mon.Stop() - return false + return nil } var status Status if healthy { @@ -160,5 +188,5 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) { } } - return true + return nil } diff --git a/internal/watcher/health/raw.go b/internal/watcher/health/raw.go index b45d4d8..588cd76 100644 --- a/internal/watcher/health/raw.go +++ b/internal/watcher/health/raw.go @@ -3,7 +3,6 @@ package health import ( "net" - "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/types" ) @@ -14,9 +13,9 @@ type ( } ) -func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor { +func NewRawHealthMonitor(url types.URL, config *HealthCheckConfig) *RawHealthMonitor { mon := new(RawHealthMonitor) - mon.monitor = newMonitor(task, url, config, mon.checkAvail) + mon.monitor = newMonitor(url, config, mon.CheckHealth) mon.dialer = &net.Dialer{ Timeout: config.Timeout, FallbackDelay: -1, @@ -24,14 +23,22 @@ func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckCon return mon } -func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) { - conn, dialErr := mon.dialer.DialContext(mon.task.Context(), mon.url.Scheme, mon.url.Host) +func NewRawHealthChecker(url types.URL, config *HealthCheckConfig) HealthChecker { + return NewRawHealthMonitor(url, config) +} + +func (mon *RawHealthMonitor) CheckHealth() (healthy bool, detail string, err error) { + ctx, cancel := mon.ContextWithTimeout("ping request timed out") + defer cancel() + + url := mon.url.Load() + conn, dialErr := mon.dialer.DialContext(ctx, url.Scheme, url.Host) if dialErr != nil { detail = dialErr.Error() /* trunk-ignore(golangci-lint/nilerr) */ return } conn.Close() - avail = true + healthy = true return }