From 2628d9e8a85920d9aff274d1ec17b854afd3f0ac Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 28 May 2025 22:07:13 +0800 Subject: [PATCH] fix(task): refactor task module and fix reload stuck/error, fix some logic --- internal/config/config.go | 4 +- internal/docker/client.go | 16 +- internal/idlewatcher/watcher.go | 7 +- internal/net/gphttp/server/error.go | 3 +- internal/net/gphttp/server/server.go | 23 +- internal/route/provider/provider.go | 7 +- internal/task/debug.go | 58 +++-- internal/task/impl.go | 237 --------------------- internal/task/task.go | 302 ++++++++++++++++----------- internal/task/task_debug.go | 6 - internal/task/task_prod.go | 4 - internal/task/task_test.go | 33 ++- internal/task/utils.go | 66 ++++-- internal/task/with.go | 48 +++++ 14 files changed, 371 insertions(+), 443 deletions(-) delete mode 100644 internal/task/impl.go create mode 100644 internal/task/with.go diff --git a/internal/config/config.go b/internal/config/config.go index 02cac28..75d771f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -117,13 +117,13 @@ func Reload() gperr.Error { newCfg := newConfig() err := newCfg.load() if err != nil { - newCfg.task.Finish(err) + newCfg.task.FinishAndWait(err) return gperr.New(ansi.Warning("using last config")).With(err) } // cancel all current subtasks -> wait // -> replace config -> start new subtasks - config.GetInstance().(*Config).Task().Finish("config changed") + config.GetInstance().(*Config).Task().FinishAndWait("config changed") newCfg.Start(StartAllServers) config.SetInstance(newCfg) return nil diff --git a/internal/docker/client.go b/internal/docker/client.go index e9395f7..d2a44cc 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -58,20 +58,16 @@ func initClientCleaner() { case <-ticker.C: closeTimedOutClients() case <-cleaner.Context().Done(): + clientMapMu.Lock() + for _, c := range clientMap { + delete(clientMap, c.Key()) + c.Client.Close() + } + clientMapMu.Unlock() return } } }() - - task.OnProgramExit("docker_clients_cleanup", func() { - clientMapMu.Lock() - defer clientMapMu.Unlock() - - for _, c := range clientMap { - delete(clientMap, c.Key()) - c.Client.Close() - } - }) } func closeTimedOutClients() { diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go index 54682ba..dd7ac13 100644 --- a/internal/idlewatcher/watcher.go +++ b/internal/idlewatcher/watcher.go @@ -92,7 +92,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { // same address, likely two routes from the same container return w, nil } - w.task.Finish(causeReload) + w.task.FinishAndWait(causeReload) } watcherMapMu.RUnlock() @@ -156,14 +156,15 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { w.task = parent.Subtask("idlewatcher."+r.Name(), true) watcherMapMu.Lock() - defer watcherMapMu.Unlock() watcherMap[key] = w + watcherMapMu.Unlock() + go func() { cause := w.watchUntilDestroy() if errors.Is(cause, causeContainerDestroy) || errors.Is(cause, task.ErrProgramExiting) { watcherMapMu.Lock() - defer watcherMapMu.Unlock() delete(watcherMap, key) + watcherMapMu.Unlock() w.l.Info().Msg("idlewatcher stopped") } else if !errors.Is(cause, causeReload) { gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l) diff --git a/internal/net/gphttp/server/error.go b/internal/net/gphttp/server/error.go index 5a20a03..5fd8d5c 100644 --- a/internal/net/gphttp/server/error.go +++ b/internal/net/gphttp/server/error.go @@ -3,6 +3,7 @@ package server import ( "context" "errors" + "net" "net/http" "github.com/rs/zerolog" @@ -10,7 +11,7 @@ import ( func convertError(err error) error { switch { - case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): + case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled), errors.Is(err, net.ErrClosed): return nil default: return err diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index 3977ef6..5e9a391 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -3,6 +3,8 @@ package server import ( "context" "crypto/tls" + "errors" + "io" "net" "net/http" "time" @@ -105,7 +107,7 @@ func (s *Server) Start(parent task.Parent) { } Start(subtask, h3, s.acl, &s.l) if s.http != nil { - s.http.Handler = advertiseHTTP3(s.http.Handler, h3) + s.http.Handler = advertiseHTTP3(s.http.Handler, h3) } // s.https is not nil (checked above) s.https.Handler = advertiseHTTP3(s.https.Handler, h3) @@ -115,7 +117,7 @@ func (s *Server) Start(parent task.Parent) { Start(subtask, s.https, s.acl, &s.l) } -func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) { +func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) (port int) { if srv == nil { return } @@ -138,6 +140,7 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l HandleError(logger, err, "failed to listen on port") return } + port = l.Addr().(*net.TCPAddr).Port if srv.TLSConfig != nil { l = tls.NewListener(l, srv.TLSConfig) } @@ -145,32 +148,36 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l l = acl.WrapTCP(l) } serveFunc = getServeFunc(l, srv.Serve) + task.OnCancel("stop", func() { + stop(srv, l, logger) + }) case *http3.Server: l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr) if err != nil { HandleError(logger, err, "failed to listen on port") return } + port = l.LocalAddr().(*net.UDPAddr).Port if acl != nil { l = acl.WrapUDP(l) } serveFunc = getServeFunc(l, srv.Serve) + task.OnCancel("stop", func() { + stop(srv, l, logger) + }) } - task.OnCancel("stop", func() { - stop(srv, logger) - }) logStarted(srv, logger) go func() { err := convertError(serveFunc()) if err != nil { - HandleError(logger, err, "failed to serve "+proto+" server") + HandleError(logger, err, "failed to serve "+proto+" server") } task.Finish(err) }() return port } -func stop[Server httpServer](srv Server, logger *zerolog.Logger) { +func stop[Server httpServer](srv Server, l io.Closer, logger *zerolog.Logger) { if srv == nil { return } @@ -180,7 +187,7 @@ func stop[Server httpServer](srv Server, logger *zerolog.Logger) { ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second) defer cancel() - if err := convertError(srv.Shutdown(ctx)); err != nil { + if err := convertError(errors.Join(srv.Shutdown(ctx), l.Close())); err != nil { HandleError(logger, err, "failed to shutdown "+proto+" server") } else { logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped") diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index ffdf049..e3a6c21 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -100,18 +100,21 @@ func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error { func (p *Provider) Start(parent task.Parent) gperr.Error { t := parent.Subtask("provider."+p.String(), false) + routesTask := t.Subtask("routes", false) errs := gperr.NewBuilder("routes error") for _, r := range p.routes { - errs.Add(p.startRoute(t, r)) + errs.Add(p.startRoute(routesTask, r)) } eventQueue := events.NewEventQueue( t.Subtask("event_queue", false), providerEventFlushInterval, func(events []events.Event) { + routesTask.FinishAndWait("reload routes") + routesTask = t.Subtask("routes", false) handler := p.newEventHandler() // routes' lifetime should follow the provider's lifetime - handler.Handle(t, events) + handler.Handle(routesTask, events) handler.Log() }, func(err gperr.Error) { diff --git a/internal/task/debug.go b/internal/task/debug.go index 4c29f1e..e1849dd 100644 --- a/internal/task/debug.go +++ b/internal/task/debug.go @@ -8,45 +8,59 @@ import ( ) // debug only. -func (t *Task) listStuckedCallbacks() []string { - t.mu.Lock() - defer t.mu.Unlock() - callbacks := make([]string, 0, len(t.callbacksOnFinish)) - for c := range t.callbacksOnFinish { - callbacks = append(callbacks, c.about) +func listStuckedCallbacks(t *Task) []string { + callbacks := make([]string, 0) + if t.onFinish != nil { + for c := range t.onFinish.Range { + callbacks = append(callbacks, c.about) + } + } + if t.onCancel != nil { + for c := range t.onCancel.Range { + callbacks = append(callbacks, c.about) + } + } + if t.children != nil { + for c := range t.children.Range { + callbacks = append(callbacks, listStuckedCallbacks(c)...) + } } return callbacks } // debug only. -func (t *Task) listStuckedChildren() []string { - t.mu.Lock() - defer t.mu.Unlock() - children := make([]string, 0, len(t.children)) - for c := range t.children { - if c.isFinished() { - continue - } - children = append(children, c.String()) - if len(c.children) > 0 { - children = append(children, c.listStuckedChildren()...) +func listStuckedChildren(t *Task) []string { + if t.children != nil { + children := make([]string, 0) + for c := range t.children.Range { + children = append(children, c.String()) + children = append(children, listStuckedCallbacks(c)...) } + return children } - return children + return nil } func (t *Task) reportStucked() { - callbacks := t.listStuckedCallbacks() - children := t.listStuckedChildren() + callbacks := listStuckedCallbacks(t) + children := listStuckedChildren(t) if len(callbacks) == 0 && len(children) == 0 { return } fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children))) if len(callbacks) > 0 { - fmtOutput.Add(gperr.New("callbacks").With(gperr.Multiline().AddLinesString(callbacks...))) + callbackBuilder := gperr.NewBuilder("callbacks") + for _, c := range callbacks { + callbackBuilder.Adds(c) + } + fmtOutput.Add(callbackBuilder.Error()) } if len(children) > 0 { - fmtOutput.Add(gperr.New("children").With(gperr.Multiline().AddLinesString(children...))) + childrenBuilder := gperr.NewBuilder("children") + for _, c := range children { + childrenBuilder.Adds(c) + } + fmtOutput.Add(childrenBuilder.Error()) } log.Warn().Msg(fmtOutput.String()) } diff --git a/internal/task/impl.go b/internal/task/impl.go deleted file mode 100644 index 9c5faf7..0000000 --- a/internal/task/impl.go +++ /dev/null @@ -1,237 +0,0 @@ -package task - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/rs/zerolog/log" -) - -var ( - taskPool = make(chan *Task, 100) - - voidTask = &Task{ctx: context.Background()} - root = newRoot() - - cancelCtx context.Context -) - -func init() { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - cancelCtx = ctx //nolint:fatcontext - - voidTask.parent = root -} - -func testCleanup() { - root = newRoot() -} - -func newRoot() *Task { - return newTask("root", voidTask, true) -} - -func noCancel(error) { - // do nothing -} - -//go:inline -func newTask(name string, parent *Task, needFinish bool) *Task { - var t *Task - select { - case t = <-taskPool: - t.finished.Store(false) - default: - t = &Task{} - } - t.name = name - t.parent = parent - if needFinish { - t.ctx, t.cancel = context.WithCancelCause(parent.ctx) - } else { - t.ctx, t.cancel = parent.ctx, noCancel - } - return t -} - -//go:inline -func (t *Task) needFinish() bool { - return t.ctx != t.parent.ctx -} - -//go:inline -func (t *Task) isCanceled() bool { - return t.cancel == nil -} - -//go:inline -func putTask(t *Task) { - select { - case taskPool <- t: - default: - return - } -} - -//go:inline -func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) { - if !t.needFinish() { - if waitSubTasks { - t.parent.addCallback(about, func() { - if !t.waitFinish(taskTimeout) { - t.reportStucked() - } - fn() - }, false) - } else { - t.parent.addCallback(about, fn, false) - } - return - } - - if !waitSubTasks { - t.mu.Lock() - defer t.mu.Unlock() - if t.callbacksOnCancel == nil { - t.callbacksOnCancel = make(callbacksSet) - go func() { - <-t.ctx.Done() - for c := range t.callbacksOnCancel { - go func() { - invokeWithRecover(c) - t.mu.Lock() - delete(t.callbacksOnCancel, c) - t.mu.Unlock() - }() - } - }() - } - t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{} - return - } - - t.mu.Lock() - defer t.mu.Unlock() - - if t.isCanceled() { - log.Panic(). - Str("task", t.String()). - Str("callback", about). - Msg("callback added to canceled task") - return - } - - if t.callbacksOnFinish == nil { - t.callbacksOnFinish = make(callbacksSet) - } - t.callbacksOnFinish[&Callback{ - fn: fn, - about: about, - }] = struct{}{} -} - -//go:inline -func (t *Task) addChild(child *Task) { - t.mu.Lock() - defer t.mu.Unlock() - - if t.isCanceled() { - log.Panic(). - Str("task", t.String()). - Str("child", child.Name()). - Msg("child added to canceled task") - return - } - - if t.children == nil { - t.children = make(childrenSet) - } - t.children[child] = struct{}{} -} - -//go:inline -func (t *Task) removeChild(child *Task) { - t.mu.Lock() - defer t.mu.Unlock() - delete(t.children, child) -} - -func (t *Task) runOnFinishCallbacks() { - if len(t.callbacksOnFinish) == 0 { - return - } - - for c := range t.callbacksOnFinish { - go func() { - invokeWithRecover(c) - t.mu.Lock() - delete(t.callbacksOnFinish, c) - t.mu.Unlock() - }() - } -} - -func (t *Task) waitFinish(timeout time.Duration) bool { - // return directly if already finished - if t.isFinished() { - return true - } - - t.mu.Lock() - children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish - t.mu.Unlock() - - ok := true - if len(children) != 0 { - ok = waitEmpty(children, timeout) - } - - if len(callbacksOnCancel) != 0 { - ok = ok && waitEmpty(callbacksOnCancel, timeout) - } - - if len(callbacksOnFinish) != 0 { - ok = ok && waitEmpty(callbacksOnFinish, timeout) - } - - return ok -} - -//go:inline -func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool { - if len(set) == 0 { - return true - } - - timer := time.NewTimer(timeout) - defer timer.Stop() - - for { - if len(set) == 0 { - return true - } - select { - case <-timer.C: - return false - default: - time.Sleep(100 * time.Millisecond) - } - } -} - -//go:inline -func fmtCause(cause any) error { - switch cause := cause.(type) { - case nil: - return nil - case error: - return cause - case string: - return errors.New(cause) - default: - return fmt.Errorf("%v", cause) - } -} diff --git a/internal/task/task.go b/internal/task/task.go index abb45cf..40e060a 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,15 +1,10 @@ -// This file has the abstract logic of the task system. -// -// The implementation of the task system is in the impl.go file. package task import ( "context" "sync" - "sync/atomic" "time" - "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" ) @@ -35,21 +30,23 @@ type ( // // Use Task.Finish to stop all subtasks of the Task. Task struct { - name string + parent *Task + name string + ctx context.Context + cancel context.CancelCauseFunc + done chan struct{} + finishCalled bool + onCancel *withWg[*Callback] + onFinish *withWg[*Callback] + children *withWg[*Task] - parent *Task - children childrenSet - callbacksOnFinish callbacksSet - callbacksOnCancel callbacksSet - - ctx context.Context - cancel context.CancelCauseFunc - - finished atomic.Bool - mu sync.Mutex + mu sync.Mutex } Parent interface { Context() context.Context + // Subtask returns a new subtask with the given name, derived from the parent's context. + // + // This should not be called after Finish is called on the task or its parent task. Subtask(name string, needFinish bool) *Task Name() string Finish(reason any) @@ -57,124 +54,193 @@ type ( } ) -type ( - childrenSet = map[*Task]struct{} - callbacksSet = map[*Callback]struct{} -) - const taskTimeout = 3 * time.Second func (t *Task) Context() context.Context { return t.ctx } -// FinishCause returns the reason / error that caused the task to be finished. -func (t *Task) FinishCause() error { - return context.Cause(t.ctx) -} - -// OnFinished calls fn when the task is canceled and all subtasks are finished. -// -// It should not be called after Finish is called. -func (t *Task) OnFinished(about string, fn func()) { - t.addCallback(about, fn, true) -} - -// OnCancel calls fn when the task is canceled. -// -// It should not be called after Finish is called. -func (t *Task) OnCancel(about string, fn func()) { - t.addCallback(about, fn, false) -} - -// Finish cancel all subtasks and wait for them to finish, -// then marks the task as finished, with the given reason (if any). -func (t *Task) Finish(reason any) { - t.mu.Lock() - if t.isCanceled() { - t.mu.Unlock() - return - } - t.cancel(fmtCause(reason)) - t.ctx, t.cancel = cancelCtx, nil - - t.mu.Unlock() - - t.finishAndWait() - t.finished.Store(true) -} - -func (t *Task) finishAndWait() { - ok := true - - if !waitEmpty(t.children, taskTimeout) { - t.reportStucked() - ok = false - } - t.runOnFinishCallbacks() - - if !t.waitFinish(taskTimeout) { - t.reportStucked() - ok = false - } - // clear anyway - clear(t.children) - clear(t.callbacksOnFinish) - - if t != root && t.needFinish() { - t.parent.removeChild(t) - } - logFinished(t) - - if ok { - putTask(t) - } -} - -func (t *Task) isFinished() bool { - return t.finished.Load() -} - -// Subtask returns a new subtask with the given name, derived from the parent's context. -// -// This should not be called after Finish is called on the task or its parent task. -func (t *Task) Subtask(name string, needFinish bool) *Task { - panicIfFinished(t, "Subtask is called") - - child := newTask(name, t, needFinish) - - if needFinish { - t.addChild(child) - } - - logStarted(child) - return child -} - -// Name returns the name of the task without parent names. func (t *Task) Name() string { return t.name } // String returns the full name of the task. func (t *Task) String() string { - if t.parent != root { - return t.parent.String() + "." + t.name - } - return t.name + return t.fullName() } // MarshalText implements encoding.TextMarshaler. func (t *Task) MarshalText() ([]byte, error) { - return []byte(t.String()), nil + return []byte(t.fullName()), nil } -func invokeWithRecover(cb *Callback) { - defer func() { - if err := recover(); err != nil { - log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic") - panicWithDebugStack() - } - }() - cb.fn() +// Finish marks the task as finished, with the given reason (if any). +func (t *Task) Finish(reason any) { + t.finish(reason, false) +} + +// FinishCause returns the reason / error that caused the task to be finished. +func (t *Task) FinishCause() error { + return context.Cause(t.ctx) +} + +// FinishAndWait cancel all subtasks and wait for them to finish, +// then marks the task as finished, with the given reason (if any). +func (t *Task) FinishAndWait(reason any) { + t.finish(reason, true) +} + +// OnFinished calls fn when the task is canceled and all subtasks are finished. +// +// It should not be called after Finish is called. +func (t *Task) OnFinished(about string, fn func()) { + if !t.needFinish() { + t.OnCancel(about, fn) + return + } + + t.mu.Lock() + if t.onFinish == nil { + t.onFinish = newWithWg[*Callback]() + t.mu.Unlock() + + go func() { + <-t.ctx.Done() + <-t.done + for cb := range t.onFinish.Range { + go func(cb *Callback) { + invokeWithRecover(cb) + t.onFinish.Delete(cb) + }(cb) + } + }() + } else { + t.mu.Unlock() + } + + t.onFinish.Add(&Callback{fn: fn, about: about}) +} + +// OnCancel calls fn when the task is canceled. +// +// It should not be called after Finish is called. +func (t *Task) OnCancel(about string, fn func()) { + t.mu.Lock() + if t.onCancel == nil { + t.onCancel = newWithWg[*Callback]() + t.mu.Unlock() + + go func() { + <-t.ctx.Done() + for cb := range t.onCancel.Range { + go func(cb *Callback) { + invokeWithRecover(cb) + t.onCancel.Delete(cb) + }(cb) + } + }() + } else { + t.mu.Unlock() + } + + t.onCancel.Add(&Callback{fn: fn, about: about}) +} + +// Subtask returns a new subtask with the given name, derived from the parent's context. +// +// This should not be called after Finish is called on the task or its parent task. +func (t *Task) Subtask(name string, needFinish bool) *Task { + t.mu.Lock() + if t.children == nil { + t.children = newWithWg[*Task]() + t.mu.Unlock() + } else { + t.mu.Unlock() + } + + child := &Task{ + name: name, + parent: t, + } + + t.children.Add(child) + + child.ctx, child.cancel = context.WithCancelCause(t.ctx) + + if needFinish { + child.done = make(chan struct{}) + } else { + child.done = closedCh + go func() { + <-child.ctx.Done() + child.Finish(t.FinishCause()) + }() + } + + logStarted(child) + return child +} + +func (t *Task) finish(reason any, wait bool) { + t.mu.Lock() + if t.finishCalled { + t.mu.Unlock() + // wait but not report stucked (again) + t.waitFinish(taskTimeout) + return + } + + t.finishCalled = true + t.mu.Unlock() + + if t.needFinish() { + close(t.done) + } + + t.cancel(fmtCause(reason)) + if wait && !t.waitFinish(taskTimeout) { + t.reportStucked() + } + if t != root { + t.parent.children.Delete(t) + } + logFinished(t) +} + +func (t *Task) waitFinish(timeout time.Duration) bool { + if t.children == nil && t.onCancel == nil && t.onFinish == nil { + return true + } + done := make(chan struct{}) + go func() { + if t.children != nil { + t.children.Wait() + } + if t.onCancel != nil { + t.onCancel.Wait() + } + if t.onFinish != nil { + t.onFinish.Wait() + } + <-t.done + close(done) + }() + timeoutCh := time.After(timeout) + select { + case <-done: + return true + case <-timeoutCh: + return false + } +} + +func (t *Task) fullName() string { + if t.parent == root { + return t.name + } + return t.parent.fullName() + "." + t.name +} + +func (t *Task) needFinish() bool { + return t.done != closedCh } diff --git a/internal/task/task_debug.go b/internal/task/task_debug.go index 80a03dc..873853e 100644 --- a/internal/task/task_debug.go +++ b/internal/task/task_debug.go @@ -12,12 +12,6 @@ func panicWithDebugStack() { panic(string(debug.Stack())) } -func panicIfFinished(t *Task, reason string) { - if t.isFinished() { - log.Panic().Msg("task " + t.String() + " is finished but " + reason) - } -} - func logStarted(t *Task) { log.Info().Msg("task " + t.String() + " started") } diff --git a/internal/task/task_prod.go b/internal/task/task_prod.go index 4a07d2f..90a18a8 100644 --- a/internal/task/task_prod.go +++ b/internal/task/task_prod.go @@ -6,10 +6,6 @@ func panicWithDebugStack() { // do nothing } -func panicIfFinished(t *Task, reason string) { - // do nothing -} - func logStarted(t *Task) { // do nothing } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 0341a02..091ef8e 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - . "github.com/yusing/go-proxy/internal/utils/testing" + expect "github.com/yusing/go-proxy/internal/utils/testing" ) func testTask() *Task { @@ -35,7 +35,7 @@ func TestChildTaskCancellation(t *testing.T) { select { case <-child.Context().Done(): - ExpectError(t, context.Canceled, child.Context().Err()) + expect.ErrorIs(t, context.Canceled, child.Context().Err()) default: t.Fatal("subTask context was not canceled as expected") } @@ -80,10 +80,10 @@ func TestTaskOnCancelOnFinished(t *testing.T) { shouldTrueOnFinish = true }) - ExpectFalse(t, shouldTrueOnFinish) + expect.False(t, shouldTrueOnFinish) task.Finish(nil) - ExpectTrue(t, shouldTrueOnCancel) - ExpectTrue(t, shouldTrueOnFinish) + expect.True(t, shouldTrueOnCancel) + expect.True(t, shouldTrueOnFinish) } func TestCommonFlowWithGracefulShutdown(t *testing.T) { @@ -108,29 +108,28 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) { } }() - ExpectNoError(t, gracefulShutdown(1*time.Second)) - time.Sleep(100 * time.Millisecond) - ExpectTrue(t, finished) + expect.NoError(t, gracefulShutdown(1*time.Second)) + expect.True(t, finished) - ExpectTrue(t, root.waitFinish(1*time.Second)) - ExpectError(t, context.Canceled, context.Cause(task.Context())) - ExpectError(t, ErrProgramExiting, task.Context().Err()) - ExpectError(t, ErrProgramExiting, task.FinishCause()) + expect.ErrorIs(t, ErrProgramExiting, context.Cause(task.Context())) + expect.ErrorIs(t, context.Canceled, task.Context().Err()) + expect.ErrorIs(t, ErrProgramExiting, task.FinishCause()) } func TestTimeoutOnGracefulShutdown(t *testing.T) { t.Cleanup(testCleanup) _ = testTask() - ExpectError(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond)) + expect.ErrorIs(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond)) } func TestFinishMultipleCalls(t *testing.T) { t.Cleanup(testCleanup) task := testTask() var wg sync.WaitGroup - wg.Add(5) - for range 5 { + n := 20 + wg.Add(n) + for range n { go func() { defer wg.Done() task.Finish(nil) @@ -157,8 +156,8 @@ func BenchmarkTasksNeedFinish(b *testing.B) { func BenchmarkContextWithCancel(b *testing.B) { for b.Loop() { - task, taskCancel := context.WithCancel(b.Context()) - taskCancel() + task, taskCancel := context.WithCancelCause(b.Context()) + taskCancel(nil) <-task.Done() } } diff --git a/internal/task/utils.go b/internal/task/utils.go index 7994548..5b63fa7 100644 --- a/internal/task/utils.go +++ b/internal/task/utils.go @@ -3,6 +3,7 @@ package task import ( "context" "errors" + "fmt" "os" "os/signal" "syscall" @@ -13,6 +14,31 @@ import ( var ErrProgramExiting = errors.New("program exiting") +var root *Task + +var closedCh = make(chan struct{}) + +func init() { + close(closedCh) + initRoot() +} + +func initRoot() { + ctx, cancel := context.WithCancelCause(context.Background()) + root = &Task{ + name: "root", + ctx: ctx, + cancel: cancel, + done: closedCh, + } + root.parent = root +} + +func testCleanup() { + root.cancel(nil) + initRoot() +} + // RootTask returns a new Task with the given name, derived from the root context. // //go:inline @@ -29,7 +55,7 @@ func RootContextCanceled() <-chan struct{} { } func OnProgramExit(about string, fn func()) { - root.OnFinished(about, fn) + root.OnCancel(about, fn) } // WaitExit waits for a signal to shutdown the program, and then waits for all tasks to finish, up to the given timeout. @@ -59,19 +85,33 @@ func WaitExit(shutdownTimeout int) { // still running when the timeout was reached, and their current tree // of subtasks. func gracefulShutdown(timeout time.Duration) error { - root.mu.Lock() - if root.isCanceled() { - cause := context.Cause(root.ctx) - root.mu.Unlock() - return cause - } - root.mu.Unlock() - - root.cancel(ErrProgramExiting) - ok := waitEmpty(root.children, timeout) - root.runOnFinishCallbacks() - if !ok || !root.waitFinish(timeout) { + go root.Finish(ErrProgramExiting) + if !root.waitFinish(timeout) { return context.DeadlineExceeded } return nil } + +func invokeWithRecover(cb *Callback) { + defer func() { + if err := recover(); err != nil { + log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic") + panicWithDebugStack() + } + }() + cb.fn() +} + +//go:inline +func fmtCause(cause any) error { + switch cause := cause.(type) { + case nil: + return nil + case error: + return cause + case string: + return errors.New(cause) + default: + return fmt.Errorf("%v", cause) + } +} diff --git a/internal/task/with.go b/internal/task/with.go new file mode 100644 index 0000000..bfb47ad --- /dev/null +++ b/internal/task/with.go @@ -0,0 +1,48 @@ +package task + +import ( + "sync" + + "github.com/puzpuzpuz/xsync/v4" +) + +type withWg[T comparable] struct { + m *xsync.Map[T, struct{}] + wg sync.WaitGroup +} + +func newWithWg[T comparable]() *withWg[T] { + return &withWg[T]{ + m: xsync.NewMap[T, struct{}](), + } +} + +func (w *withWg[T]) Add(ele T) { + w.wg.Add(1) + w.m.Store(ele, struct{}{}) +} + +func (w *withWg[T]) AddWithoutWG(ele T) { + w.m.Store(ele, struct{}{}) +} + +func (w *withWg[T]) Delete(key T) { + w.wg.Done() + w.m.Delete(key) +} + +func (w *withWg[T]) DeleteWithoutWG(key T) { + w.m.Delete(key) +} + +func (w *withWg[T]) Wait() { + w.wg.Wait() +} + +func (w *withWg[T]) Range(yield func(T) bool) { + for ele := range w.m.Range { + if !yield(ele) { + break + } + } +}