From 1ab34ed46fa207a241ccb7214474b7d52d7b7c21 Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 1 Jan 2025 06:07:32 +0800 Subject: [PATCH] simplify task package implementation --- .trunk/trunk.yaml | 6 +- Makefile | 4 +- cmd/main.go | 3 +- internal/api/v1/list.go | 2 +- internal/autocert/provider.go | 4 +- internal/config/config.go | 22 +- internal/docker/client.go | 2 +- internal/docker/idlewatcher/waker.go | 28 +- internal/docker/idlewatcher/watcher.go | 9 +- internal/entrypoint/entrypoint.go | 2 +- internal/error/utils.go | 6 +- internal/net/http/accesslog/config.go | 3 +- internal/net/http/accesslog/file_logger.go | 22 + internal/net/http/accesslog/formatter.go | 27 +- .../net/http/loadbalancer/loadbalancer.go | 15 +- .../http/middleware/errorpage/error_page.go | 12 +- internal/net/http/server/server.go | 9 +- internal/notif/dispatcher.go | 2 +- internal/route/http.go | 41 +- internal/route/provider/event_handler.go | 6 +- internal/route/provider/provider.go | 30 +- internal/route/stream.go | 30 +- internal/route/stream_impl.go | 2 +- internal/task/task.go | 397 ++++++------------ internal/task/task_test.go | 160 +++---- internal/task/utils.go | 96 +++++ internal/utils/atomic.go | 30 -- internal/utils/atomic/atomic_value.go | 30 ++ internal/utils/functional/map.go | 25 +- internal/utils/functional/set.go | 4 + internal/utils/testing/testing.go | 17 +- internal/watcher/config_file_watcher.go | 4 +- internal/watcher/directory_watcher.go | 18 +- internal/watcher/events/event_queue.go | 58 +-- internal/watcher/health/monitor/monitor.go | 21 +- 35 files changed, 547 insertions(+), 600 deletions(-) create mode 100644 internal/net/http/accesslog/file_logger.go create mode 100644 internal/task/utils.go delete mode 100644 internal/utils/atomic.go create mode 100644 internal/utils/atomic/atomic_value.go diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 8f50f86..7c01a4d 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -23,16 +23,16 @@ lint: enabled: - hadolint@2.12.1-beta - actionlint@1.7.4 - - checkov@3.2.334 + - checkov@3.2.344 - git-diff-check - gofmt@1.20.4 - golangci-lint@1.62.2 - - osv-scanner@1.9.1 + - osv-scanner@1.9.2 - oxipng@9.1.3 - prettier@3.4.2 - shellcheck@0.10.0 - shfmt@3.6.0 - - trufflehog@3.86.1 + - trufflehog@3.88.0 actions: disabled: - trunk-announce diff --git a/Makefile b/Makefile index 2e628b5..349d99a 100755 --- a/Makefile +++ b/Makefile @@ -28,10 +28,10 @@ get: go get -u ./cmd && go mod tidy debug: - GODOXY_DEBUG=1 make run + GODOXY_DEBUG=1 BUILD_FLAGS="" make run debug-trace: - GODOXY_DEBUG=1 GODOXY_TRACE=1 run + GODOXY_TRACE=1 make debug profile: GODEBUG=gctrace=1 make debug diff --git a/cmd/main.go b/cmd/main.go index 872b024..ec79e13 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -159,8 +159,7 @@ func main() { // grafully shutdown logging.Info().Msg("shutting down") - task.CancelGlobalContext() - _ = task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown)) + _ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown)) } func prepareDirectory(dir string) { diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index fc1e8d8..02c4b8b 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -52,7 +52,7 @@ func List(w http.ResponseWriter, r *http.Request) { case ListHomepageConfig: U.RespondJSON(w, r, config.HomepageConfig()) case ListTasks: - U.RespondJSON(w, r, task.DebugTaskMap()) + U.RespondJSON(w, r, task.DebugTaskList()) default: U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index de9ded7..2c45aa2 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -153,8 +153,8 @@ func (p *Provider) ScheduleRenewal() { return } go func() { - task := task.GlobalTask("cert renew scheduler") - defer task.Finish("cert renew scheduler stopped") + task := task.RootTask("cert-renew-scheduler", true) + defer task.Finish(nil) for { renewalTime := p.ShouldRenewOn() diff --git a/internal/config/config.go b/internal/config/config.go index cf9001f..1127ef7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -53,7 +53,7 @@ func newConfig() *Config { return &Config{ value: types.DefaultConfig(), providers: F.NewMapOf[string, *proxy.Provider](), - task: task.GlobalTask("config"), + task: task.RootTask("config", false), } } @@ -76,21 +76,19 @@ func MatchDomains() []string { } func WatchChanges() { - task := task.GlobalTask("config watcher") + t := task.RootTask("config_watcher", true) eventQueue := events.NewEventQueue( - task, + t, configEventFlushInterval, OnConfigChange, func(err E.Error) { E.LogError("config reload error", err, &logger) }, ) - eventQueue.Start(cfgWatcher.Events(task.Context())) + eventQueue.Start(cfgWatcher.Events(t.Context())) } -func OnConfigChange(flushTask *task.Task, ev []events.Event) { - defer flushTask.Finish("config reload complete") - +func OnConfigChange(ev []events.Event) { // no matter how many events during the interval // just reload once and check the last event switch ev[len(ev)-1].Action { @@ -116,14 +114,14 @@ func Reload() E.Error { newCfg := newConfig() err := newCfg.load() if err != nil { + newCfg.task.Finish(err) return err } // cancel all current subtasks -> wait // -> replace config -> start new subtasks instance.task.Finish("config changed") - instance.task.Wait() - *instance = *newCfg + instance = newCfg instance.StartProxyProviders() return nil } @@ -143,8 +141,7 @@ func (cfg *Config) Task() *task.Task { func (cfg *Config) StartProxyProviders() { errs := cfg.providers.CollectErrorsParallel( func(_ string, p *proxy.Provider) error { - subtask := cfg.task.Subtask(p.String()) - return p.Start(subtask) + return p.Start(cfg.task) }) if err := E.Join(errs...); err != nil { @@ -209,9 +206,6 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) } func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { - subtask := cfg.task.Subtask("load route providers") - defer subtask.Finish("done") - errs := E.NewBuilder("route provider errors") results := E.NewBuilder("loaded route providers") diff --git a/internal/docker/client.go b/internal/docker/client.go index f3b176f..7128740 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -38,7 +38,7 @@ var ( ) func init() { - task.GlobalTask("close docker clients").OnFinished("", func() { + task.OnProgramExit("docker_clients_cleanup", func() { clientMap.RangeAllParallel(func(_ string, c Client) { if c.Connected() { c.Client.Close() diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 1ae21eb..ff589fc 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -38,7 +38,7 @@ const ( // TODO: support stream -func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { +func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { hcCfg := entry.RawEntry().HealthCheck hcCfg.Timeout = idleWakerCheckTimeout @@ -46,8 +46,8 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP rp: rp, stream: stream, } - - watcher, err := registerWatcher(providerSubTask, entry, waker) + task := parent.Subtask("idlewatcher") + watcher, err := registerWatcher(task, entry, waker) if err != nil { return nil, E.Errorf("register watcher: %w", err) } @@ -63,7 +63,7 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP if common.PrometheusEnabled { m := metrics.GetServiceMetrics() - fqn := providerSubTask.Parent().Name() + "/" + entry.TargetName() + fqn := parent.Name() + "/" + entry.TargetName() waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn)) waker.metric.Set(float64(watcher.Status())) } @@ -71,19 +71,18 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP } // lifetime should follow route provider. -func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { - return newWaker(providerSubTask, entry, rp, nil) +func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { + return newWaker(parent, entry, rp, nil) } -func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) { - return newWaker(providerSubTask, entry, nil, stream) +func NewStreamWaker(parent task.Parent, entry route.Entry, stream net.Stream) (Waker, E.Error) { + return newWaker(parent, entry, nil, stream) } // Start implements health.HealthMonitor. -func (w *Watcher) Start(routeSubTask *task.Task) E.Error { - routeSubTask.Finish("ignored") - w.task.OnCancel("stop route and cleanup", func() { - routeSubTask.Parent().Finish(w.task.FinishCause()) +func (w *Watcher) Start(parent task.Parent) E.Error { + w.task.OnCancel("route_cleanup", func() { + parent.Finish(w.task.FinishCause()) if w.metric != nil { w.metric.Reset() } @@ -91,6 +90,11 @@ func (w *Watcher) Start(routeSubTask *task.Task) E.Error { return nil } +// Task implements health.HealthMonitor. +func (w *Watcher) Task() *task.Task { + return w.task +} + // Finish implements health.HealthMonitor. func (w *Watcher) Finish(reason any) { if w.stream != nil { diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index 17afcb8..aac0eee 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -51,7 +51,7 @@ var ( const dockerReqTimeout = 3 * time.Second -func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) { +func registerWatcher(watcherTask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) { cfg := entry.IdlewatcherConfig() if cfg.IdleTimeout == 0 { @@ -67,7 +67,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker w.Config = cfg w.waker = waker w.resetIdleTimer() - providerSubtask.Finish("used existing watcher") + watcherTask.Finish("used existing watcher") return w, nil } @@ -81,7 +81,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker Config: cfg, waker: waker, client: client, - task: providerSubtask, + task: watcherTask, ticker: time.NewTicker(cfg.IdleTimeout), } w.stopByMethod = w.getStopCallback() @@ -210,8 +210,7 @@ func (w *Watcher) resetIdleTimer() { } func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { - eventTask = w.task.Subtask("docker event watcher") - eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{ + eventCh, errCh = dockerWatcher.EventsWithOptions(w.Task().Context(), watcher.DockerListOptions{ Filters: watcher.NewDockerFilter( watcher.DockerFilterContainer, watcher.DockerFilterContainerNameID(w.ContainerID), diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index a16f454..82012ad 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -54,7 +54,7 @@ func SetMiddlewares(mws []map[string]any) error { return nil } -func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) { +func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) { epAccessLoggerMu.Lock() defer epAccessLoggerMu.Unlock() diff --git a/internal/error/utils.go b/internal/error/utils.go index 05179c6..3bb5e70 100644 --- a/internal/error/utils.go +++ b/internal/error/utils.go @@ -50,10 +50,12 @@ func Join(errors ...error) Error { if n == 0 { return nil } - errs := make([]error, 0, n) + errs := make([]error, n) + i := 0 for _, err := range errors { if err != nil { - errs = append(errs, err) + errs[i] = err + i++ } } return &nestedError{Extras: errs} diff --git a/internal/net/http/accesslog/config.go b/internal/net/http/accesslog/config.go index 69115c7..23d96c8 100644 --- a/internal/net/http/accesslog/config.go +++ b/internal/net/http/accesslog/config.go @@ -22,6 +22,7 @@ type ( Path string `json:"path" validate:"required"` Filters Filters `json:"filters"` Fields Fields `json:"fields"` + // Retention *Retention } ) @@ -31,7 +32,7 @@ var ( FormatJSON Format = "json" ) -const DefaultBufferSize = 100 +const DefaultBufferSize = 64 * 1024 // 64KB func DefaultConfig() *Config { return &Config{ diff --git a/internal/net/http/accesslog/file_logger.go b/internal/net/http/accesslog/file_logger.go new file mode 100644 index 0000000..aeaa241 --- /dev/null +++ b/internal/net/http/accesslog/file_logger.go @@ -0,0 +1,22 @@ +package accesslog + +import ( + "fmt" + "os" + "sync" + + "github.com/yusing/go-proxy/internal/task" +) + +type File struct { + *os.File + sync.Mutex +} + +func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) { + f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return nil, fmt.Errorf("access log open error: %w", err) + } + return NewAccessLogger(parent, &File{File: f}, cfg), nil +} diff --git a/internal/net/http/accesslog/formatter.go b/internal/net/http/accesslog/formatter.go index 9ad9350..0a02b96 100644 --- a/internal/net/http/accesslog/formatter.go +++ b/internal/net/http/accesslog/formatter.go @@ -7,18 +7,17 @@ import ( "net/http" "net/url" "strconv" + "time" ) type ( CommonFormatter struct { - cfg *Fields - } - CombinedFormatter struct { - CommonFormatter - } - JSONFormatter struct { - CommonFormatter + cfg *Fields + GetTimeNow func() time.Time // for testing purposes only } + CombinedFormatter CommonFormatter + JSONFormatter CommonFormatter + JSONLogEntry struct { Time string `json:"time"` IP string `json:"ip"` @@ -39,6 +38,8 @@ type ( } ) +const LogTimeFormat = "02/Jan/2006:15:04:05 -0700" + func scheme(req *http.Request) string { if req.TLS != nil { return "https" @@ -62,7 +63,7 @@ func clientIP(req *http.Request) string { return req.RemoteAddr } -func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { +func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { query := f.cfg.Query.ProcessQuery(req.URL.Query()) line.WriteString(req.Host) @@ -71,7 +72,7 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http line.WriteString(clientIP(req)) line.WriteString(" - - [") - line.WriteString(timeNow()) + line.WriteString(f.GetTimeNow().Format(LogTimeFormat)) line.WriteString("] \"") line.WriteString(req.Method) @@ -86,8 +87,8 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http line.WriteString(strconv.FormatInt(res.ContentLength, 10)) } -func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { - f.CommonFormatter.Format(line, req, res) +func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { + (*CommonFormatter)(f).Format(line, req, res) line.WriteString(" \"") line.WriteString(req.Referer()) line.WriteString("\" \"") @@ -95,14 +96,14 @@ func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *ht line.WriteRune('"') } -func (f JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { +func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { query := f.cfg.Query.ProcessQuery(req.URL.Query()) headers := f.cfg.Headers.ProcessHeaders(req.Header) headers.Del("Cookie") cookies := f.cfg.Cookies.ProcessCookies(req.Cookies()) entry := JSONLogEntry{ - Time: timeNow(), + Time: f.GetTimeNow().Format(LogTimeFormat), IP: clientIP(req), Method: req.Method, Scheme: scheme(req), diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 26192b7..b79f208 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -9,6 +9,7 @@ import ( "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" + "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health/monitor" @@ -52,10 +53,13 @@ func New(cfg *Config) *LoadBalancer { } // Start implements task.TaskStarter. -func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error { +func (lb *LoadBalancer) Start(parent task.Parent) E.Error { lb.startTime = time.Now() - lb.task = routeSubtask - lb.task.OnFinished("loadbalancer cleanup", func() { + lb.task = parent.Subtask("loadbalancer."+lb.Link, false) + parent.OnCancel("lb_remove_route", func() { + routes.DeleteHTTPRoute(lb.Link) + }) + lb.task.OnFinished("cleanup", func() { if lb.impl != nil { lb.pool.RangeAll(func(k string, v *Server) { lb.impl.OnRemoveServer(v) @@ -66,6 +70,11 @@ func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error { return nil } +// Task implements task.TaskStarter. +func (lb *LoadBalancer) Task() *task.Task { + return lb.task +} + // Finish implements task.TaskFinisher. func (lb *LoadBalancer) Finish(reason any) { lb.task.Finish(reason) diff --git a/internal/net/http/middleware/errorpage/error_page.go b/internal/net/http/middleware/errorpage/error_page.go index 113a467..a7eb003 100644 --- a/internal/net/http/middleware/errorpage/error_page.go +++ b/internal/net/http/middleware/errorpage/error_page.go @@ -32,10 +32,10 @@ func setup() { return } - task := task.GlobalTask("error page") - dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath) + t := task.RootTask("error_page", true) + dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath) loadContent() - go watchDir(task) + go watchDir() } func GetStaticFile(filename string) ([]byte, bool) { @@ -73,11 +73,11 @@ func loadContent() { } } -func watchDir(task *task.Task) { - eventCh, errCh := dirWatcher.Events(task.Context()) +func watchDir() { + eventCh, errCh := dirWatcher.Events(task.RootContext()) for { select { - case <-task.Context().Done(): + case <-task.RootContextCanceled(): return case event, ok := <-eventCh: if !ok { diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index b38faf4..fde17fd 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -24,8 +24,6 @@ type Server struct { httpsStarted bool startTime time.Time - task *task.Task - l zerolog.Logger } @@ -76,7 +74,6 @@ func NewServer(opt Options) (s *Server) { CertProvider: opt.CertProvider, http: httpSer, https: httpsSer, - task: task.GlobalTask(opt.Name + " server"), l: logger, } } @@ -108,7 +105,7 @@ func (s *Server) Start() { s.l.Info().Str("addr", s.https.Addr).Msgf("server started") } - s.task.OnFinished("stop server", s.stop) + task.OnProgramExit("server."+s.Name+".stop", s.stop) } func (s *Server) stop() { @@ -117,12 +114,12 @@ func (s *Server) stop() { } if s.http != nil && s.httpStarted { - s.handleErr("http", s.http.Shutdown(s.task.Context())) + s.handleErr("http", s.http.Shutdown(task.RootContext())) s.httpStarted = false } if s.https != nil && s.httpsStarted { - s.handleErr("https", s.https.Shutdown(s.task.Context())) + s.handleErr("https", s.https.Shutdown(task.RootContext())) s.httpsStarted = false } } diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 7faba08..b354af4 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -35,7 +35,7 @@ var ( const dispatchErr = "notification dispatch error" -func StartNotifDispatcher(parent *task.Task) *Dispatcher { +func StartNotifDispatcher(parent task.Parent) *Dispatcher { dispatcher = &Dispatcher{ task: parent.Subtask("notification"), logCh: make(chan *LogMessage), diff --git a/internal/route/http.go b/internal/route/http.go index 6d68c5a..67c4718 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -73,19 +73,17 @@ func (r *HTTPRoute) String() string { return r.TargetName() } -// Start implements*task.TaskStarter. -func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { +// Start implements task.TaskStarter. +func (r *HTTPRoute) Start(parent task.Parent) E.Error { if entry.ShouldNotServe(r) { - providerSubtask.Finish("should not serve") return nil } - r.task = providerSubtask + r.task = parent.Subtask("http."+r.TargetName(), false) switch { case entry.UseIdleWatcher(r): - wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName()) - waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp) + waker, err := idlewatcher.NewHTTPWaker(r.task, r.ReverseProxyEntry, r.rp) if err != nil { r.task.Finish(err) return err @@ -98,7 +96,7 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { if err == nil { fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck) r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback) - r.task.OnCancel("close docker client", client.Close) + r.task.OnCancel("close_docker_client", client.Close) } } if r.HealthMon == nil { @@ -137,29 +135,32 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { } if r.HealthMon != nil { - healthMonTask := r.task.Subtask("health monitor") - if err := r.HealthMon.Start(healthMonTask); err != nil { + if err := r.HealthMon.Start(r.task); err != nil { E.LogWarn("health monitor error", err, &r.l) - healthMonTask.Finish(err) } } if entry.UseLoadBalance(r) { - r.addToLoadBalancer() + r.addToLoadBalancer(parent) } else { routes.SetHTTPRoute(r.TargetName(), r) - r.task.OnFinished("remove from route table", func() { + r.task.OnFinished("entrypoint_remove_route", func() { routes.DeleteHTTPRoute(r.TargetName()) }) } if common.PrometheusEnabled { - r.task.OnFinished("unreg metrics", r.rp.UnregisterMetrics) + r.task.OnFinished("metrics_cleanup", r.rp.UnregisterMetrics) } return nil } -// Finish implements*task.TaskFinisher. +// Task implements task.TaskStarter. +func (r *HTTPRoute) Task() *task.Task { + return r.task +} + +// Finish implements task.TaskFinisher. func (r *HTTPRoute) Finish(reason any) { r.task.Finish(reason) } @@ -168,7 +169,7 @@ func (r *HTTPRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.handler.ServeHTTP(w, req) } -func (r *HTTPRoute) addToLoadBalancer() { +func (r *HTTPRoute) addToLoadBalancer(parent task.Parent) { var lb *loadbalancer.LoadBalancer cfg := r.Raw.LoadBalance l, ok := routes.GetHTTPRoute(cfg.Link) @@ -182,11 +183,7 @@ func (r *HTTPRoute) addToLoadBalancer() { } } else { lb = loadbalancer.New(cfg) - lbTask := r.task.Parent().Subtask("loadbalancer " + cfg.Link) - lbTask.OnCancel("remove lb from routes", func() { - routes.DeleteHTTPRoute(cfg.Link) - }) - if err := lb.Start(lbTask); err != nil { + if err := lb.Start(parent); err != nil { panic(err) // should always return nil } linked = &HTTPRoute{ @@ -203,9 +200,9 @@ func (r *HTTPRoute) addToLoadBalancer() { routes.SetHTTPRoute(cfg.Link, linked) } r.loadBalancer = lb - r.server = loadbalance.NewServer(r.task.String(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon) + r.server = loadbalance.NewServer(r.task.Name(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon) lb.AddServer(r.server) - r.task.OnCancel("remove server from lb", func() { + r.task.OnCancel("lb_remove_server", func() { lb.RemoveServer(r.server) }) } diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 653db11..46a95b6 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler { } } -func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) { +func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) { oldRoutes := handler.provider.routes newRoutes, err := handler.provider.loadRoutesImpl() if err != nil { @@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool return false } -func (handler *EventHandler) Add(parent *task.Task, route *route.Route) { +func (handler *EventHandler) Add(parent task.Parent, route *route.Route) { err := handler.provider.startRoute(parent, route) if err != nil { handler.errs.Add(err.Subject("add")) @@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) { handler.removed.Adds(route.Entry.Alias) } -func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) { +func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) { oldRoute.Finish("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index cb09fb2..a834ac0 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "path" + "strings" "time" "github.com/rs/zerolog" @@ -60,7 +61,7 @@ func NewFileProvider(filename string) (p *Provider, err error) { if name == "" { return nil, ErrEmptyProviderName } - p = newProvider(name, ProviderTypeFile) + p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile) p.ProviderImpl, err = FileProviderImpl(filename) if err != nil { return nil, err @@ -100,46 +101,43 @@ func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error { - subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias) - err := r.Start(subtask) +func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error { + err := r.Start(parent) if err != nil { - p.routes.Delete(r.Entry.Alias) - subtask.Finish(err) // just to ensure return err.Subject(r.Entry.Alias) } + p.routes.Store(r.Entry.Alias, r) - subtask.OnFinished("del from provider", func() { + r.Task().OnFinished("provider_remove_route", func() { p.routes.Delete(r.Entry.Alias) }) return nil } // Start implements*task.TaskStarter. -func (p *Provider) Start(configSubtask *task.Task) E.Error { - // routes and event queue will stop on parent cancel - providerTask := configSubtask +func (p *Provider) Start(parent task.Parent) E.Error { + t := parent.Subtask("provider."+p.name, false) + // routes and event queue will stop on config reload errs := p.routes.CollectErrorsParallel( func(alias string, r *R.Route) error { - return p.startRoute(providerTask, r) + return p.startRoute(t, r) }) eventQueue := events.NewEventQueue( - providerTask, + t.Subtask("event_queue", false), providerEventFlushInterval, - func(flushTask *task.Task, events []events.Event) { + func(events []events.Event) { handler := p.newEventHandler() // routes' lifetime should follow the provider's lifetime - handler.Handle(providerTask, events) + handler.Handle(t, events) handler.Log() - flushTask.Finish("events flushed") }, func(err E.Error) { E.LogError("event error", err, p.Logger()) }, ) - eventQueue.Start(p.watcher.Events(providerTask.Context())) + eventQueue.Start(p.watcher.Events(t.Context())) if err := E.Join(errs...); err != nil { return err.Subject(p.String()) diff --git a/internal/route/stream.go b/internal/route/stream.go index 065adf9..3560e86 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -47,20 +47,22 @@ func (r *StreamRoute) String() string { return "stream " + r.TargetName() } -// Start implements*task.TaskStarter. -func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { +// Start implements task.TaskStarter. +func (r *StreamRoute) Start(parent task.Parent) E.Error { if entry.ShouldNotServe(r) { - providerSubtask.Finish("should not serve") return nil } - r.task = providerSubtask + r.task = parent.Subtask("stream." + r.TargetName()) r.Stream = NewStream(r) + parent.OnCancel("finish", func() { + r.task.Finish(nil) + }) + switch { case entry.UseIdleWatcher(r): - wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName()) - waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) + waker, err := idlewatcher.NewStreamWaker(r.task, r.StreamEntry, r.Stream) if err != nil { r.task.Finish(err) return err @@ -73,7 +75,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { if err == nil { fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck) r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback) - r.task.OnCancel("close docker client", client.Close) + r.task.OnCancel("close_docker_client", client.Close) } } if r.HealthMon == nil { @@ -86,7 +88,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { return E.From(err) } - r.task.OnFinished("close stream", func() { + r.task.OnFinished("close_stream", func() { if err := r.Stream.Close(); err != nil { E.LogError("close stream failed", err, &r.l) } @@ -97,22 +99,26 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { Msg("listening") if r.HealthMon != nil { - healthMonTask := r.task.Subtask("health monitor") - if err := r.HealthMon.Start(healthMonTask); err != nil { + if err := r.HealthMon.Start(r.task); err != nil { E.LogWarn("health monitor error", err, &r.l) - healthMonTask.Finish(err) } } go r.acceptConnections() routes.SetStreamRoute(r.TargetName(), r) - r.task.OnFinished("remove from route table", func() { + r.task.OnFinished("entrypoint_remove_route", func() { routes.DeleteStreamRoute(r.TargetName()) }) return nil } +// Task implements task.TaskStarter. +func (r *StreamRoute) Task() *task.Task { + return r.task +} + +// Finish implements task.TaskFinisher. func (r *StreamRoute) Finish(reason any) { r.task.Finish(reason) } diff --git a/internal/route/stream_impl.go b/internal/route/stream_impl.go index e5ebeae..5908bea 100644 --- a/internal/route/stream_impl.go +++ b/internal/route/stream_impl.go @@ -95,7 +95,7 @@ func (stream *Stream) Handle(conn types.StreamConn) error { return fmt.Errorf("unexpected listener type: %T", stream) } case io.ReadWriteCloser: - stream.task.OnCancel("close conn", func() { conn.Close() }) + stream.task.OnCancel("close_conn", func() { conn.Close() }) dialer := &net.Dialer{Timeout: streamDialTimeout} dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) diff --git a/internal/task/task.go b/internal/task/task.go index 0d0cedc..2ed1597 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -4,355 +4,194 @@ import ( "context" "errors" "fmt" - "strings" + "runtime/debug" "sync" - "time" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/logging" - F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" ) -var globalTask = createGlobalTask() - -func createGlobalTask() (t *Task) { - t = new(Task) - t.name = "root" - t.ctx, t.cancel = context.WithCancelCause(context.Background()) - t.subtasks = F.NewSet[*Task]() - return -} - -func testResetGlobalTask() { - globalTask = createGlobalTask() -} - type ( TaskStarter interface { // Start starts the object that implements TaskStarter, // and returns an error if it fails to start. // - // The task passed must be a subtask of the caller task. - // // callerSubtask.Finish must be called when start fails or the object is finished. - Start(callerSubtask *Task) E.Error + Start(parent Parent) E.Error + Task() *Task } TaskFinisher interface { - // Finish marks the task as finished and cancel its context. - // - // Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished - // of the task to finish. - // - // Note that it will also cancel all subtasks. Finish(reason any) } // Task controls objects' lifetime. // // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. // - // When passing a Task object to another function, - // it must be a sub-Task of the current Task, - // in name of "`currentTaskName`Subtask" - // // Use Task.Finish to stop all subtasks of the Task. Task struct { + name string + + children sync.WaitGroup + + onFinished sync.WaitGroup + finished chan struct{} + ctx context.Context cancel context.CancelCauseFunc - parent *Task - subtasks F.Set[*Task] - subTasksWg sync.WaitGroup - - name string - - OnFinishedFuncs []func() - OnFinishedMu sync.Mutex - onFinishedWg sync.WaitGroup - - finishOnce sync.Once + once sync.Once + } + Parent interface { + Context() context.Context + Subtask(name string, needFinish ...bool) *Task + Name() string + Finish(reason any) + OnCancel(name string, f func()) } ) -var ( - ErrProgramExiting = errors.New("program exiting") - ErrTaskCanceled = errors.New("task canceled") - - logger = logging.With().Str("module", "task").Logger() -) - -// GlobalTask returns a new Task with the given name, derived from the global context. -func GlobalTask(format string, args ...any) *Task { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } - return globalTask.Subtask(format) -} - -// DebugTaskMap returns a map[string]any representation of the global task tree. -// -// The returned map is suitable for encoding to JSON, and can be used -// to debug the task tree. -// -// The returned map is not guaranteed to be stable, and may change -// between runs of the program. It is intended for debugging purposes -// only. -func DebugTaskMap() map[string]any { - return globalTask.serialize() -} - -// CancelGlobalContext cancels the global task context, which will cause all tasks -// created to be canceled. This should be called before exiting the program -// to ensure that all tasks are properly cleaned up. -func CancelGlobalContext() { - globalTask.cancel(ErrProgramExiting) -} - -// GlobalContextWait waits for all tasks to finish, up to the given timeout. -// -// If the timeout is exceeded, it prints a list of all tasks that were -// still running when the timeout was reached, and their current tree -// of subtasks. -func GlobalContextWait(timeout time.Duration) (err error) { - done := make(chan struct{}) - after := time.After(timeout) - go func() { - globalTask.Wait() - close(done) - }() - for { - select { - case <-done: - return - case <-after: - logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) - return context.DeadlineExceeded - } - } -} - -func (t *Task) trace(msg string) { - logger.Trace().Str("name", t.name).Msg(msg) -} - -// Name returns the name of the task. -func (t *Task) Name() string { - if !common.IsTrace { - return t.name - } - parts := strings.Split(t.name, " > ") - return parts[len(parts)-1] -} - -// String returns the name of the task. -func (t *Task) String() string { - return t.name -} - -// Context returns the context associated with the task. This context is -// canceled when Finish of the task is called, or parent task is canceled. func (t *Task) Context() context.Context { return t.ctx } +func (t *Task) Finished() <-chan struct{} { + return t.finished +} + // FinishCause returns the reason / error that caused the task to be finished. func (t *Task) FinishCause() error { - cause := context.Cause(t.ctx) - if cause == nil { - return t.ctx.Err() - } - return cause + return context.Cause(t.ctx) } -// Parent returns the parent task of the current task. -func (t *Task) Parent() *Task { - return t.parent -} - -func (t *Task) runAllOnFinished(onCompTask *Task) { - <-t.ctx.Done() - t.WaitSubTasks() - for _, OnFinishedFunc := range t.OnFinishedFuncs { - OnFinishedFunc() - t.onFinishedWg.Done() - } - onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done")) -} - -// OnFinished calls fn when all subtasks are finished. +// OnFinished calls fn when the task is canceled and all subtasks are finished. // -// It cannot be called after Finish or Wait is called. +// It should not be called after Finish is called. func (t *Task) OnFinished(about string, fn func()) { - if t.parent == globalTask { - t.OnCancel(about, fn) - return - } - t.onFinishedWg.Add(1) - t.OnFinishedMu.Lock() - defer t.OnFinishedMu.Unlock() - - if t.OnFinishedFuncs == nil { - onCompTask := GlobalTask(t.name + " > OnFinished > " + about) - go t.runAllOnFinished(onCompTask) - } - idx := len(t.OnFinishedFuncs) - wrapped := func() { - defer func() { - if err := recover(); err != nil { - logger.Error(). - Str("name", t.name). - Interface("err", err). - Msg("panic in " + about) - } - }() - fn() - logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about) - } - t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped) + t.onCancel(about, fn, true) } // OnCancel calls fn when the task is canceled. // -// It cannot be called after Finish or Wait is called. +// It should not be called after Finish is called. func (t *Task) OnCancel(about string, fn func()) { - onCompTask := GlobalTask(t.name + " > OnFinished") + t.onCancel(about, fn, false) +} + +func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) { + t.onFinished.Add(1) go func() { <-t.ctx.Done() - fn() - onCompTask.Finish("done") - t.trace("onCancel done: " + about) + if waitSubTasks { + t.children.Wait() + } + t.invokeWithRecover(fn, about) + t.onFinished.Done() }() } -// Finish marks the task as finished and cancel its context. -// -// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished -// of the task to finish. -// -// Note that it will also cancel all subtasks. +// Finish cancel all subtasks and wait for them to finish, +// then marks the task as finished, with the given reason (if any). func (t *Task) Finish(reason any) { - var format string - switch reason.(type) { - case error: - format = "%w" - case string, fmt.Stringer: - format = "%s" + select { + case <-t.finished: + return default: - format = "%v" + t.once.Do(func() { + t.finish(reason) + }) + } +} + +func (t *Task) finish(reason any) { + t.cancel(fmtCause(reason)) + t.children.Wait() + t.onFinished.Wait() + if t.finished != nil { + close(t.finished) + } + logger.Trace().Msg("task " + t.name + " finished") +} + +func fmtCause(cause any) error { + switch cause := cause.(type) { + case nil: + return nil + case error: + return cause + case string: + return errors.New(cause) + default: + return fmt.Errorf("%v", cause) } - t.finishOnce.Do(func() { - t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason)) - }) - t.Wait() } // Subtask returns a new subtask with the given name, derived from the parent's context. // -// If the parent's context is already canceled, the returned subtask will be canceled immediately. -// -// This should not be called after Finish, Wait, or WaitSubTasks is called. -func (t *Task) Subtask(name string) *Task { - ctx, cancel := context.WithCancelCause(t.ctx) - return t.newSubTask(ctx, cancel, name) -} +// This should not be called after Finish is called. +func (t *Task) Subtask(name string, needFinish ...bool) *Task { + nf := len(needFinish) == 0 || needFinish[0] -func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task { - parent := t - if common.IsTrace { - name = parent.name + " > " + name - } - subtask := &Task{ + ctx, cancel := context.WithCancelCause(t.ctx) + child := &Task{ + finished: make(chan struct{}, 1), ctx: ctx, cancel: cancel, - name: name, - parent: parent, - subtasks: F.NewSet[*Task](), } - parent.subTasksWg.Add(1) - parent.subtasks.Add(subtask) - if common.IsTrace { - subtask.trace("started") + if t != root { + child.name = t.name + "." + name + allTasks.Add(child) + } else { + child.name = name + } + + allTasksWg.Add(1) + t.children.Add(1) + + if !nf { go func() { - subtask.Wait() - subtask.trace("finished: " + subtask.FinishCause().Error()) + <-child.ctx.Done() + child.Finish(nil) }() } + go func() { - subtask.Wait() - parent.subtasks.Remove(subtask) - parent.subTasksWg.Done() + <-child.finished + allTasksWg.Done() + t.children.Done() + allTasks.Remove(child) }() - return subtask + + logger.Trace().Msg("task " + child.name + " started") + return child } -// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish. -// -// It must be called only after Finish is called. -func (t *Task) Wait() { - <-t.ctx.Done() - t.WaitSubTasks() - t.onFinishedWg.Wait() +// Name returns the name of the task without parent names. +func (t *Task) Name() string { + parts := strutils.SplitRune(t.name, '.') + return parts[len(parts)-1] } -// WaitSubTasks waits for all subtasks of the task to finish. -// -// No more subtasks can be added after this call. -// -// It can be called before Finish is called. -func (t *Task) WaitSubTasks() { - t.subTasksWg.Wait() +// String returns the full name of the task. +func (t *Task) String() string { + return t.name } -// tree returns a string representation of the task tree, with the given -// prefix prepended to each line. The prefix is used to indent the tree, -// and should be a string of spaces or a similar separator. -// -// The resulting string is suitable for printing to the console, and can be -// used to debug the task tree. -// -// The tree is traversed in a depth-first manner, with each task's name and -// line number (if available) printed on a separate line. The line number is -// only printed if the task was created with a non-empty line argument. -// -// The returned string is not guaranteed to be stable, and may change between -// runs of the program. It is intended for debugging purposes only. -func (t *Task) tree(prefix ...string) string { - var sb strings.Builder - var pre string - if len(prefix) > 0 { - pre = prefix[0] - sb.WriteString(pre + "- ") - } - sb.WriteString(t.Name() + "\n") - t.subtasks.RangeAll(func(subtask *Task) { - sb.WriteString(subtask.tree(pre + " ")) - }) - return sb.String() +// MarshalText implements encoding.TextMarshaler. +func (t *Task) MarshalText() ([]byte, error) { + return []byte(t.name), nil } -// serialize returns a map[string]any representation of the task tree. -// -// The map contains the following keys: -// - name: the name of the task -// - subtasks: a slice of maps, each representing a subtask -// -// The subtask maps contain the same keys, recursively. -// -// The returned map is suitable for encoding to JSON, and can be used -// to debug the task tree. -// -// The returned map is not guaranteed to be stable, and may change -// between runs of the program. It is intended for debugging purposes -// only. -func (t *Task) serialize() map[string]any { - m := make(map[string]any) - parts := strings.Split(t.name, " > ") - m["name"] = parts[len(parts)-1] - if t.subtasks.Size() > 0 { - m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size()) - t.subtasks.RangeAll(func(subtask *Task) { - m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize()) - }) - } - return m +func (t *Task) invokeWithRecover(fn func(), caller string) { + defer func() { + if err := recover(); err != nil { + logger.Error(). + Interface("err", err). + Msg("panic in task " + t.name + "." + caller) + if common.IsDebug { + panic(string(debug.Stack())) + } + } + }() + fn() } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 752343d..c5b866e 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -2,132 +2,112 @@ package task import ( "context" - "sync/atomic" + "sync" "testing" "time" . "github.com/yusing/go-proxy/internal/utils/testing" ) -const ( - rootTaskName = "root-task" - subTaskName = "subtask" -) - -func TestTaskCreation(t *testing.T) { - rootTask := GlobalTask(rootTaskName) - subTask := rootTask.Subtask(subTaskName) - - ExpectEqual(t, rootTaskName, rootTask.Name()) - ExpectEqual(t, subTaskName, subTask.Name()) +func testTask() *Task { + return RootTask("test", false) } -func TestTaskCancellation(t *testing.T) { - subTaskDone := make(chan struct{}) +func TestChildTaskCancellation(t *testing.T) { + t.Cleanup(testCleanup) - rootTask := GlobalTask(rootTaskName) - subTask := rootTask.Subtask(subTaskName) + parent := testTask() + child := parent.Subtask("") go func() { - subTask.Wait() - close(subTaskDone) + defer child.Finish(nil) + for { + select { + case <-child.Context().Done(): + return + default: + continue + } + } }() - go rootTask.Finish(nil) + parent.cancel(nil) // should also cancel child select { - case <-subTaskDone: - err := subTask.Context().Err() - ExpectError(t, context.Canceled, err) - cause := context.Cause(subTask.Context()) - ExpectError(t, ErrTaskCanceled, cause) - case <-time.After(1 * time.Second): + case <-child.Finished(): + ExpectError(t, context.Canceled, child.Context().Err()) + default: t.Fatal("subTask context was not canceled as expected") } } -func TestOnComplete(t *testing.T) { - rootTask := GlobalTask(rootTaskName) - task := rootTask.Subtask(subTaskName) +func TestTaskOnCancelOnFinished(t *testing.T) { + t.Cleanup(testCleanup) + task := testTask() - var value atomic.Int32 - task.OnFinished("set value", func() { - value.Store(1234) + var shouldTrueOnCancel bool + var shouldTrueOnFinish bool + + task.OnCancel("", func() { + shouldTrueOnCancel = true }) + task.OnFinished("", func() { + shouldTrueOnFinish = true + }) + + ExpectFalse(t, shouldTrueOnFinish) task.Finish(nil) - ExpectEqual(t, value.Load(), 1234) + ExpectTrue(t, shouldTrueOnCancel) + ExpectTrue(t, shouldTrueOnFinish) } -func TestGlobalContextWait(t *testing.T) { - testResetGlobalTask() - defer CancelGlobalContext() +func TestCommonFlowWithGracefulShutdown(t *testing.T) { + t.Cleanup(testCleanup) + task := testTask() - rootTask := GlobalTask(rootTaskName) + finished := false - finished1, finished2 := false, false - - subTask1 := rootTask.Subtask(subTaskName) - subTask2 := rootTask.Subtask(subTaskName) - subTask1.OnFinished("", func() { - finished1 = true - }) - subTask2.OnFinished("", func() { - finished2 = true + task.OnFinished("", func() { + finished = true }) go func() { - time.Sleep(500 * time.Millisecond) - subTask1.Finish(nil) + defer task.Finish(nil) + for { + select { + case <-task.Context().Done(): + return + default: + continue + } + } }() - go func() { - time.Sleep(500 * time.Millisecond) - subTask2.Finish(nil) - }() + ExpectNoError(t, GracefulShutdown(1*time.Second)) + ExpectTrue(t, finished) - go func() { - subTask1.Wait() - subTask2.Wait() - rootTask.Finish(nil) - }() - - _ = GlobalContextWait(1 * time.Second) - ExpectTrue(t, finished1) - ExpectTrue(t, finished2) - ExpectError(t, context.Canceled, rootTask.Context().Err()) - ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context())) - ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context())) + <-root.finished + ExpectError(t, context.Canceled, task.Context().Err()) + ExpectError(t, ErrProgramExiting, task.FinishCause()) } -func TestTimeoutOnGlobalContextWait(t *testing.T) { - testResetGlobalTask() +func TestTimeoutOnGracefulShutdown(t *testing.T) { + t.Cleanup(testCleanup) + _ = testTask() - rootTask := GlobalTask(rootTaskName) - rootTask.Subtask(subTaskName) - - ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond)) + ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond)) } -func TestGlobalContextCancellation(t *testing.T) { - testResetGlobalTask() - - taskDone := make(chan struct{}) - rootTask := GlobalTask(rootTaskName) - - go func() { - rootTask.Wait() - close(taskDone) - }() - - CancelGlobalContext() - - select { - case <-taskDone: - err := rootTask.Context().Err() - ExpectError(t, context.Canceled, err) - cause := context.Cause(rootTask.Context()) - ExpectError(t, ErrProgramExiting, cause) - case <-time.After(1 * time.Second): - t.Fatal("subTask context was not canceled as expected") +func TestFinishMultipleCalls(t *testing.T) { + t.Cleanup(testCleanup) + task := testTask() + var wg sync.WaitGroup + wg.Add(5) + for range 5 { + go func() { + defer wg.Done() + task.Finish(nil) + }() } + wg.Wait() } diff --git a/internal/task/utils.go b/internal/task/utils.go new file mode 100644 index 0000000..79c8fce --- /dev/null +++ b/internal/task/utils.go @@ -0,0 +1,96 @@ +package task + +import ( + "context" + "encoding/json" + "errors" + "slices" + "sync" + "time" + + "github.com/yusing/go-proxy/internal/logging" + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +var ErrProgramExiting = errors.New("program exiting") + +var logger = logging.With().Str("module", "task").Logger() + +var root = newRoot() +var allTasks = F.NewSet[*Task]() +var allTasksWg sync.WaitGroup + +func testCleanup() { + root = newRoot() + allTasks.Clear() + allTasksWg = sync.WaitGroup{} +} + +// RootTask returns a new Task with the given name, derived from the root context. +func RootTask(name string, needFinish bool) *Task { + return root.Subtask(name, needFinish) +} + +func newRoot() *Task { + t := &Task{name: "root"} + t.ctx, t.cancel = context.WithCancelCause(context.Background()) + return t +} + +func RootContext() context.Context { + return root.ctx +} + +func RootContextCanceled() <-chan struct{} { + return root.ctx.Done() +} + +func OnProgramExit(about string, fn func()) { + root.OnFinished(about, fn) +} + +// GracefulShutdown waits for all tasks to finish, up to the given timeout. +// +// If the timeout is exceeded, it prints a list of all tasks that were +// still running when the timeout was reached, and their current tree +// of subtasks. +func GracefulShutdown(timeout time.Duration) (err error) { + root.cancel(ErrProgramExiting) + + done := make(chan struct{}) + after := time.After(timeout) + + go func() { + allTasksWg.Wait() + close(done) + }() + + for { + select { + case <-done: + return + case <-after: + b, err := json.Marshal(DebugTaskList()) + if err != nil { + logger.Warn().Err(err).Msg("failed to marshal tasks") + return context.DeadlineExceeded + } + logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size()) + return context.DeadlineExceeded + } + } +} + +// DebugTaskList returns list of all tasks. +// +// The returned string is suitable for printing to the console. +func DebugTaskList() []string { + l := make([]string, 0, allTasks.Size()) + + allTasks.RangeAll(func(t *Task) { + l = append(l, t.name) + }) + + slices.Sort(l) + return l +} diff --git a/internal/utils/atomic.go b/internal/utils/atomic.go deleted file mode 100644 index a2c37b3..0000000 --- a/internal/utils/atomic.go +++ /dev/null @@ -1,30 +0,0 @@ -package utils - -import ( - "encoding/json" - "sync/atomic" -) - -type AtomicValue[T any] struct { - atomic.Value -} - -func (a *AtomicValue[T]) Load() T { - return a.Value.Load().(T) -} - -func (a *AtomicValue[T]) Store(v T) { - a.Value.Store(v) -} - -func (a *AtomicValue[T]) Swap(v T) T { - return a.Value.Swap(v).(T) -} - -func (a *AtomicValue[T]) CompareAndSwap(oldV, newV T) bool { - return a.Value.CompareAndSwap(oldV, newV) -} - -func (a *AtomicValue[T]) MarshalJSON() ([]byte, error) { - return json.Marshal(a.Load()) -} diff --git a/internal/utils/atomic/atomic_value.go b/internal/utils/atomic/atomic_value.go new file mode 100644 index 0000000..7f9bfd2 --- /dev/null +++ b/internal/utils/atomic/atomic_value.go @@ -0,0 +1,30 @@ +package atomic + +import ( + "encoding/json" + "sync/atomic" +) + +type Value[T any] struct { + atomic.Value +} + +func (a *Value[T]) Load() T { + return a.Value.Load().(T) +} + +func (a *Value[T]) Store(v T) { + a.Value.Store(v) +} + +func (a *Value[T]) Swap(v T) T { + return a.Value.Swap(v).(T) +} + +func (a *Value[T]) CompareAndSwap(oldV, newV T) bool { + return a.Value.CompareAndSwap(oldV, newV) +} + +func (a *Value[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(a.Load()) +} diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index bac733a..c6d0654 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -152,9 +152,10 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error { return m.CollectErrors(do) } - errs := make([]error, 0) - mu := sync.Mutex{} - wg := sync.WaitGroup{} + var errs []error + var mu sync.Mutex + var wg sync.WaitGroup + m.Range(func(k KT, v VT) bool { wg.Add(1) go func() { @@ -171,24 +172,6 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error { return errs } -// RemoveAll removes all key-value pairs from the map where the value matches the given criteria. -// -// Parameters: -// -// criteria: function to determine whether a value should be removed -// -// Returns: -// -// nothing -func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) { - m.Range(func(k KT, v VT) bool { - if criteria(v) { - m.Delete(k) - } - return true - }) -} - func (m Map[KT, VT]) Has(k KT) bool { _, ok := m.Load(k) return ok diff --git a/internal/utils/functional/set.go b/internal/utils/functional/set.go index f34a9e5..c238ae8 100644 --- a/internal/utils/functional/set.go +++ b/internal/utils/functional/set.go @@ -61,3 +61,7 @@ func (set Set[T]) RangeAllParallel(f func(T)) { func (set Set[T]) Size() int { return set.m.Size() } + +func (set Set[T]) IsEmpty() bool { + return set.m == nil || set.m.Size() == 0 +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index a443c97..c19e5b8 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -20,10 +20,17 @@ func IgnoreError[Result any](r Result, _ error) Result { return r } +func fmtError(err error) string { + if err == nil { + return "" + } + return ansi.StripANSI(err.Error()) +} + func ExpectNoError(t *testing.T, err error) { t.Helper() - if err != nil && !reflect.ValueOf(err).IsNil() { - t.Errorf("expected err=nil, got %s", ansi.StripANSI(err.Error())) + if err != nil { + t.Errorf("expected err=nil, got %s", fmtError(err)) t.FailNow() } } @@ -31,7 +38,7 @@ func ExpectNoError(t *testing.T, err error) { func ExpectError(t *testing.T, expected error, err error) { t.Helper() if !errors.Is(err, expected) { - t.Errorf("expected err %s, got %s", expected, ansi.StripANSI(err.Error())) + t.Errorf("expected err %s, got %s", expected, fmtError(err)) t.FailNow() } } @@ -39,7 +46,7 @@ func ExpectError(t *testing.T, expected error, err error) { func ExpectError2(t *testing.T, input any, expected error, err error) { t.Helper() if !errors.Is(err, expected) { - t.Errorf("%v: expected err %s, got %s", input, expected, ansi.StripANSI(err.Error())) + t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err)) t.FailNow() } } @@ -48,7 +55,7 @@ func ExpectErrorT[T error](t *testing.T, err error) { t.Helper() var errAs T if !errors.As(err, &errAs) { - t.Errorf("expected err %T, got %s", errAs, ansi.StripANSI(err.Error())) + t.Errorf("expected err %T, got %s", errAs, fmtError(err)) t.FailNow() } } diff --git a/internal/watcher/config_file_watcher.go b/internal/watcher/config_file_watcher.go index 31087ea..fc0ccfb 100644 --- a/internal/watcher/config_file_watcher.go +++ b/internal/watcher/config_file_watcher.go @@ -16,8 +16,10 @@ var ( func NewConfigFileWatcher(filename string) Watcher { configDirWatcherMu.Lock() defer configDirWatcherMu.Unlock() + if configDirWatcher == nil { - configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath) + t := task.RootTask("config_dir_watcher", false) + configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath) } return configDirWatcher.Add(filename) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index 4fd0648..af8bceb 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -37,7 +37,7 @@ type DirWatcher struct { // // Note that the returned DirWatcher is not ready to use until the goroutine // started by NewDirectoryWatcher has finished. -func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher { +func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { @@ -56,7 +56,7 @@ func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher { fwMap: F.NewMapOf[string, *fileWatcher](), eventCh: make(chan Event), errCh: make(chan E.Error), - task: callerSubtask, + task: parent.Subtask("dir_watcher(" + dirPath + ")"), } go helper.start() return helper @@ -80,17 +80,19 @@ func (h *DirWatcher) Add(relPath string) Watcher { eventCh: make(chan Event), errCh: make(chan E.Error), } - h.task.OnFinished("close file watcher for "+relPath, func() { - close(s.eventCh) - close(s.errCh) - }) h.fwMap.Store(relPath, s) return s } +func (h *DirWatcher) cleanup() { + h.w.Close() + close(h.eventCh) + close(h.errCh) + h.task.Finish(nil) +} + func (h *DirWatcher) start() { - defer close(h.eventCh) - defer h.w.Close() + defer h.cleanup() for { select { diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go index 8ed7be3..563c8cd 100644 --- a/internal/watcher/events/event_queue.go +++ b/internal/watcher/events/event_queue.go @@ -1,6 +1,7 @@ package events import ( + "runtime/debug" "time" "github.com/yusing/go-proxy/internal/common" @@ -17,7 +18,7 @@ type ( onFlush OnFlushFunc onError OnErrorFunc } - OnFlushFunc = func(flushTask *task.Task, events []Event) + OnFlushFunc = func(events []Event) OnErrorFunc = func(err E.Error) ) @@ -38,9 +39,9 @@ const eventQueueCapacity = 10 // but the onFlush function can return earlier (e.g. run in another goroutine). // // If task is canceled before the flushInterval is reached, the events in queue will be discarded. -func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { +func NewEventQueue(queueTask *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { return &EventQueue{ - task: parent.Subtask("event queue"), + task: queueTask, queue: make([]Event, 0, eventQueueCapacity), ticker: time.NewTicker(flushInterval), flushInterval: flushInterval, @@ -50,19 +51,20 @@ func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlu } func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { - if common.IsProduction { - origOnFlush := e.onFlush - // recover panic in onFlush when in production mode - e.onFlush = func(flushTask *task.Task, events []Event) { - defer func() { - if err := recover(); err != nil { - e.onError(E.New("recovered panic in onFlush"). - Withf("%v", err). - Subject(e.task.Parent().String())) + origOnFlush := e.onFlush + // recover panic in onFlush when in production mode + e.onFlush = func(events []Event) { + defer func() { + if err := recover(); err != nil { + e.onError(E.New("recovered panic in onFlush"). + Withf("%v", err). + Subject(e.task.Name())) + if common.IsDebug { + panic(string(debug.Stack())) } - }() - origOnFlush(flushTask, events) - } + } + }() + origOnFlush(events) } go func() { @@ -75,19 +77,24 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { return case <-e.ticker.C: if len(e.queue) > 0 { - flushTask := e.task.Subtask("flush events") - queue := e.queue - e.queue = make([]Event, 0, eventQueueCapacity) - go e.onFlush(flushTask, queue) - flushTask.Wait() + // clone -> clear -> flush + queue := make([]Event, len(e.queue)) + copy(queue, e.queue) + + e.queue = e.queue[:0] + + e.onFlush(queue) } e.ticker.Reset(e.flushInterval) case event, ok := <-eventCh: - e.queue = append(e.queue, event) if !ok { return } - case err := <-errCh: + e.queue = append(e.queue, event) + case err, ok := <-errCh: + if !ok { + return + } if err != nil { e.onError(err) } @@ -95,10 +102,3 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { } }() } - -// Wait waits for all events to be flushed and the task to finish. -// -// It is safe to call this method multiple times. -func (e *EventQueue) Wait() { - e.task.Wait() -} diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go index 3e41a58..c23a476 100644 --- a/internal/watcher/health/monitor/monitor.go +++ b/internal/watcher/health/monitor/monitor.go @@ -13,7 +13,7 @@ import ( "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/task" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -23,9 +23,9 @@ type ( monitor struct { service string config *health.HealthCheckConfig - url U.AtomicValue[types.URL] + url atomic.Value[types.URL] - status U.AtomicValue[health.Status] + status atomic.Value[health.Status] lastResult *health.HealthCheckResult lastSeen time.Time @@ -59,10 +59,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance } // Start implements task.TaskStarter. -func (mon *monitor) Start(routeSubtask *task.Task) E.Error { - mon.service = routeSubtask.Parent().Name() - mon.task = routeSubtask - +func (mon *monitor) Start(parent task.Parent) E.Error { if mon.config.Interval <= 0 { return E.From(ErrNegativeInterval) } @@ -71,6 +68,9 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error { mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service)) } + mon.service = parent.Name() + mon.task = parent.Subtask("health_monitor") + go func() { logger := logging.With().Str("name", mon.service).Logger() @@ -78,10 +78,10 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error { if mon.status.Load() != health.StatusError { mon.status.Store(health.StatusUnknown) } - mon.task.Finish(nil) if mon.metric != nil { mon.metric.Reset() } + mon.task.Finish(nil) }() if err := mon.checkUpdateHealth(); err != nil { @@ -108,6 +108,11 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error { return nil } +// Task implements task.TaskStarter. +func (mon *monitor) Task() *task.Task { + return mon.task +} + // Finish implements task.TaskFinisher. func (mon *monitor) Finish(reason any) { mon.task.Finish(reason)