diff --git a/internal/docker/client.go b/internal/docker/client.go index 6d15d88..e9395f7 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -47,7 +47,7 @@ const ( ) func initClientCleaner() { - cleaner := task.RootTask("docker_clients_cleaner", false) + cleaner := task.RootTask("docker_clients_cleaner", true) go func() { ticker := time.NewTicker(cleanInterval) defer ticker.Stop() diff --git a/internal/metrics/period/poller.go b/internal/metrics/period/poller.go index 3675258..b881a6c 100644 --- a/internal/metrics/period/poller.go +++ b/internal/metrics/period/poller.go @@ -155,10 +155,11 @@ func (p *Poller[T, AggregateT]) Start() { gatherErrsTicker.Stop() saveTicker.Stop() - if err := p.save(); err != nil { + err := p.save() + if err != nil { l.Err(err).Msg("failed to save metrics data") } - t.Finish(nil) + t.Finish(err) }() l.Debug().Dur("interval", pollInterval).Msg("Starting poller") diff --git a/internal/net/gphttp/server/error.go b/internal/net/gphttp/server/error.go index 807950c..5a20a03 100644 --- a/internal/net/gphttp/server/error.go +++ b/internal/net/gphttp/server/error.go @@ -8,11 +8,15 @@ import ( "github.com/rs/zerolog" ) -func HandleError(logger *zerolog.Logger, err error, msg string) { +func convertError(err error) error { switch { case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): - return + return nil default: - logger.Fatal().Err(err).Msg(msg) + return err } } + +func HandleError(logger *zerolog.Logger, err error, msg string) { + logger.Fatal().Err(err).Msg(msg) +} diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index bf2e7f6..3977ef6 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -104,7 +104,10 @@ func (s *Server) Start(parent task.Parent) { TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig), } Start(subtask, h3, s.acl, &s.l) + if s.http != nil { s.http.Handler = advertiseHTTP3(s.http.Handler, h3) + } + // s.https is not nil (checked above) s.https.Handler = advertiseHTTP3(s.https.Handler, h3) } @@ -120,7 +123,7 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l setDebugLogger(srv, logger) proto := proto(srv) - task := parent.Subtask(proto, false) + task := parent.Subtask(proto, true) var lc net.ListenConfig var serveFunc func() error @@ -158,9 +161,13 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l }) logStarted(srv, logger) go func() { - err := serveFunc() + err := convertError(serveFunc()) + if err != nil { HandleError(logger, err, "failed to serve "+proto+" server") + } + task.Finish(err) }() + return port } func stop[Server httpServer](srv Server, logger *zerolog.Logger) { @@ -173,7 +180,7 @@ func stop[Server httpServer](srv Server, logger *zerolog.Logger) { ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second) defer cancel() - if err := srv.Shutdown(ctx); err != nil { + if err := convertError(srv.Shutdown(ctx)); err != nil { HandleError(logger, err, "failed to shutdown "+proto+" server") } else { logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped") diff --git a/internal/route/stream.go b/internal/route/stream.go index 6be8af8..b9dbf14 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -99,6 +99,11 @@ func (r *StreamRoute) HealthMonitor() health.HealthMonitor { func (r *StreamRoute) acceptConnections() { defer r.task.Finish("listener closed") + go func() { + <-r.task.Context().Done() + r.Close() + }() + for { select { case <-r.task.Context().Done(): diff --git a/internal/task/debug.go b/internal/task/debug.go index 6343cb0..4c29f1e 100644 --- a/internal/task/debug.go +++ b/internal/task/debug.go @@ -1,23 +1,27 @@ package task import ( + "fmt" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" ) // debug only. func (t *Task) listStuckedCallbacks() []string { - callbacks := make([]string, 0, len(t.callbacks)) - for c := range t.callbacks { - if !c.done.Load() { - callbacks = append(callbacks, c.about) - } + t.mu.Lock() + defer t.mu.Unlock() + callbacks := make([]string, 0, len(t.callbacksOnFinish)) + for c := range t.callbacksOnFinish { + callbacks = append(callbacks, c.about) } return callbacks } // debug only. func (t *Task) listStuckedChildren() []string { + t.mu.Lock() + defer t.mu.Unlock() children := make([]string, 0, len(t.children)) for c := range t.children { if c.isFinished() { @@ -34,13 +38,15 @@ func (t *Task) listStuckedChildren() []string { func (t *Task) reportStucked() { callbacks := t.listStuckedCallbacks() children := t.listStuckedChildren() - fmtOutput := gperr.Multiline(). - Addf("stucked callbacks: %d, stucked children: %d", - len(callbacks), len(children), - ). - Addf("callbacks"). - AddLinesString(callbacks...). - Addf("children"). - AddLinesString(children...) - log.Warn().Msg(fmtOutput.Error()) + if len(callbacks) == 0 && len(children) == 0 { + return + } + fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children))) + if len(callbacks) > 0 { + fmtOutput.Add(gperr.New("callbacks").With(gperr.Multiline().AddLinesString(callbacks...))) + } + if len(children) > 0 { + fmtOutput.Add(gperr.New("children").With(gperr.Multiline().AddLinesString(children...))) + } + log.Warn().Msg(fmtOutput.String()) } diff --git a/internal/task/impl.go b/internal/task/impl.go index 3428f2c..9c5faf7 100644 --- a/internal/task/impl.go +++ b/internal/task/impl.go @@ -4,23 +4,38 @@ import ( "context" "errors" "fmt" - "sync" "time" - _ "unsafe" + + "github.com/rs/zerolog/log" ) var ( taskPool = make(chan *Task, 100) - root = newRoot() + voidTask = &Task{ctx: context.Background()} + root = newRoot() + + cancelCtx context.Context ) +func init() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + cancelCtx = ctx //nolint:fatcontext + + voidTask.parent = root +} + func testCleanup() { root = newRoot() } func newRoot() *Task { - return newTask("root", nil, true) + return newTask("root", voidTask, true) +} + +func noCancel(error) { + // do nothing } //go:inline @@ -28,20 +43,31 @@ func newTask(name string, parent *Task, needFinish bool) *Task { var t *Task select { case t = <-taskPool: + t.finished.Store(false) default: t = &Task{} } t.name = name t.parent = parent if needFinish { - t.canceled = make(chan struct{}) + t.ctx, t.cancel = context.WithCancelCause(parent.ctx) } else { - // it will not be nil, because root task always has a canceled channel - t.canceled = parent.canceled + t.ctx, t.cancel = parent.ctx, noCancel } return t } +//go:inline +func (t *Task) needFinish() bool { + return t.ctx != t.parent.ctx +} + +//go:inline +func (t *Task) isCanceled() bool { + return t.cancel == nil +} + +//go:inline func putTask(t *Task) { select { case taskPool <- t: @@ -50,49 +76,76 @@ func putTask(t *Task) { } } -//go:inline -func (t *Task) setCause(cause error) { - if cause == nil { - t.cause = context.Canceled - } else { - t.cause = cause - } -} - //go:inline func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) { - t.mu.Lock() - if t.cause != nil { - t.mu.Unlock() + if !t.needFinish() { if waitSubTasks { - waitEmpty(t.children, taskTimeout) + t.parent.addCallback(about, func() { + if !t.waitFinish(taskTimeout) { + t.reportStucked() + } + fn() + }, false) + } else { + t.parent.addCallback(about, fn, false) } - fn() return } - defer t.mu.Unlock() - if t.callbacks == nil { - t.callbacks = make(callbacksSet) + if !waitSubTasks { + t.mu.Lock() + defer t.mu.Unlock() + if t.callbacksOnCancel == nil { + t.callbacksOnCancel = make(callbacksSet) + go func() { + <-t.ctx.Done() + for c := range t.callbacksOnCancel { + go func() { + invokeWithRecover(c) + t.mu.Lock() + delete(t.callbacksOnCancel, c) + t.mu.Unlock() + }() + } + }() + } + t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{} + return } - t.callbacks[&Callback{ - fn: fn, - about: about, - waitChildren: waitSubTasks, + + t.mu.Lock() + defer t.mu.Unlock() + + if t.isCanceled() { + log.Panic(). + Str("task", t.String()). + Str("callback", about). + Msg("callback added to canceled task") + return + } + + if t.callbacksOnFinish == nil { + t.callbacksOnFinish = make(callbacksSet) + } + t.callbacksOnFinish[&Callback{ + fn: fn, + about: about, }] = struct{}{} } //go:inline func (t *Task) addChild(child *Task) { t.mu.Lock() - if t.cause != nil { - t.mu.Unlock() - child.Finish(t.FinishCause()) + defer t.mu.Unlock() + + if t.isCanceled() { + log.Panic(). + Str("task", t.String()). + Str("child", child.Name()). + Msg("child added to canceled task") return } - defer t.mu.Unlock() - if t.children == nil { t.children = make(childrenSet) } @@ -106,67 +159,19 @@ func (t *Task) removeChild(child *Task) { delete(t.children, child) } -func (t *Task) finishChildren() { - t.mu.Lock() - if len(t.children) == 0 { - t.mu.Unlock() +func (t *Task) runOnFinishCallbacks() { + if len(t.callbacksOnFinish) == 0 { return } - var wg sync.WaitGroup - for child := range t.children { - wg.Add(1) + for c := range t.callbacksOnFinish { go func() { - defer wg.Done() - child.Finish(t.cause) + invokeWithRecover(c) + t.mu.Lock() + delete(t.callbacksOnFinish, c) + t.mu.Unlock() }() } - - clear(t.children) - t.mu.Unlock() - wg.Wait() -} - -func (t *Task) runCallbacks() { - t.mu.Lock() - if len(t.callbacks) == 0 { - t.mu.Unlock() - return - } - - var wg sync.WaitGroup - var needWait bool - - // runs callbacks that does not need wait first - for c := range t.callbacks { - if !c.waitChildren { - wg.Add(1) - go func() { - defer wg.Done() - invokeWithRecover(c) - }() - } else { - needWait = true - } - } - - // runs callbacks that need to wait for children - if needWait { - waitEmpty(t.children, taskTimeout) - for c := range t.callbacks { - if c.waitChildren { - wg.Add(1) - go func() { - defer wg.Done() - invokeWithRecover(c) - }() - } - } - } - - clear(t.callbacks) - t.mu.Unlock() - wg.Wait() } func (t *Task) waitFinish(timeout time.Duration) bool { @@ -175,16 +180,24 @@ func (t *Task) waitFinish(timeout time.Duration) bool { return true } - if len(t.children) == 0 && len(t.callbacks) == 0 { - return true + t.mu.Lock() + children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish + t.mu.Unlock() + + ok := true + if len(children) != 0 { + ok = waitEmpty(children, timeout) } - ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout) - if !ok { - return false + if len(callbacksOnCancel) != 0 { + ok = ok && waitEmpty(callbacksOnCancel, timeout) } - t.finished.Store(true) - return true + + if len(callbacksOnFinish) != 0 { + ok = ok && waitEmpty(callbacksOnFinish, timeout) + } + + return ok } //go:inline @@ -193,8 +206,6 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool { return true } - var sema uint32 - timer := time.NewTimer(timeout) defer timer.Stop() @@ -206,7 +217,7 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool { case <-timer.C: return false default: - runtime_Semacquire(&sema) + time.Sleep(100 * time.Millisecond) } } } @@ -224,6 +235,3 @@ func fmtCause(cause any) error { return fmt.Errorf("%v", cause) } } - -//go:linkname runtime_Semacquire sync.runtime_Semacquire -func runtime_Semacquire(s *uint32) diff --git a/internal/task/task.go b/internal/task/task.go index ebbb414..abb45cf 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -5,7 +5,6 @@ package task import ( "context" - "errors" "sync" "sync/atomic" "time" @@ -27,10 +26,8 @@ type ( Finish(reason any) } Callback struct { - fn func() - about string - waitChildren bool - done atomic.Bool + fn func() + about string } // Task controls objects' lifetime. // @@ -40,13 +37,13 @@ type ( Task struct { name string - parent *Task - children childrenSet - callbacks callbacksSet + parent *Task + children childrenSet + callbacksOnFinish callbacksSet + callbacksOnCancel callbacksSet - cause error - - canceled chan struct{} + ctx context.Context + cancel context.CancelCauseFunc finished atomic.Bool mu sync.Mutex @@ -68,36 +65,12 @@ type ( const taskTimeout = 3 * time.Second func (t *Task) Context() context.Context { - return t -} - -func (t *Task) Deadline() (time.Time, bool) { - return time.Time{}, false -} - -func (t *Task) Done() <-chan struct{} { - return t.canceled -} - -func (t *Task) Err() error { - t.mu.Lock() - defer t.mu.Unlock() - if t.cause == nil { - return context.Canceled - } - return t.cause -} - -func (t *Task) Value(_ any) any { - return nil + return t.ctx } // FinishCause returns the reason / error that caused the task to be finished. func (t *Task) FinishCause() error { - if t.cause == nil || errors.Is(t.cause, context.Canceled) { - return nil - } - return t.cause + return context.Cause(t.ctx) } // OnFinished calls fn when the task is canceled and all subtasks are finished. @@ -118,38 +91,44 @@ func (t *Task) OnCancel(about string, fn func()) { // then marks the task as finished, with the given reason (if any). func (t *Task) Finish(reason any) { t.mu.Lock() - if t.cause != nil { + if t.isCanceled() { t.mu.Unlock() return } - cause := fmtCause(reason) - t.setCause(cause) - // t does not need finish, it shares the canceled channel with its parent - if t == root || t.canceled != t.parent.canceled { - close(t.canceled) - } + t.cancel(fmtCause(reason)) + t.ctx, t.cancel = cancelCtx, nil + t.mu.Unlock() t.finishAndWait() + t.finished.Store(true) } func (t *Task) finishAndWait() { - defer putTask(t) + ok := true - t.finishChildren() - t.runCallbacks() + if !waitEmpty(t.children, taskTimeout) { + t.reportStucked() + ok = false + } + t.runOnFinishCallbacks() if !t.waitFinish(taskTimeout) { t.reportStucked() + ok = false } // clear anyway clear(t.children) - clear(t.callbacks) + clear(t.callbacksOnFinish) - if t != root { + if t != root && t.needFinish() { t.parent.removeChild(t) } logFinished(t) + + if ok { + putTask(t) + } } func (t *Task) isFinished() bool { @@ -179,7 +158,7 @@ func (t *Task) Name() string { // String returns the full name of the task. func (t *Task) String() string { - if t.parent != nil { + if t.parent != root { return t.parent.String() + "." + t.name } return t.name @@ -192,7 +171,6 @@ func (t *Task) MarshalText() ([]byte, error) { func invokeWithRecover(cb *Callback) { defer func() { - cb.done.Store(true) if err := recover(); err != nil { log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic") panicWithDebugStack() diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 889e5c6..0341a02 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -113,7 +113,7 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) { ExpectTrue(t, finished) ExpectTrue(t, root.waitFinish(1*time.Second)) - ExpectError(t, ErrProgramExiting, context.Cause(task.Context())) + ExpectError(t, context.Canceled, context.Cause(task.Context())) ExpectError(t, ErrProgramExiting, task.Context().Err()) ExpectError(t, ErrProgramExiting, task.FinishCause()) } diff --git a/internal/task/utils.go b/internal/task/utils.go index d45fbee..7994548 100644 --- a/internal/task/utils.go +++ b/internal/task/utils.go @@ -21,11 +21,11 @@ func RootTask(name string, needFinish bool) *Task { } func RootContext() context.Context { - return root + return root.Context() } func RootContextCanceled() <-chan struct{} { - return root.Done() + return root.Context().Done() } func OnProgramExit(about string, fn func()) { @@ -59,10 +59,18 @@ func WaitExit(shutdownTimeout int) { // still running when the timeout was reached, and their current tree // of subtasks. func gracefulShutdown(timeout time.Duration) error { - root.Finish(ErrProgramExiting) - root.finishChildren() - root.runCallbacks() - if !root.waitFinish(timeout) { + root.mu.Lock() + if root.isCanceled() { + cause := context.Cause(root.ctx) + root.mu.Unlock() + return cause + } + root.mu.Unlock() + + root.cancel(ErrProgramExiting) + ok := waitEmpty(root.children, timeout) + root.runOnFinishCallbacks() + if !ok || !root.waitFinish(timeout) { return context.DeadlineExceeded } return nil diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go index 9e0ead7..e7930da 100644 --- a/internal/watcher/health/monitor/monitor.go +++ b/internal/watcher/health/monitor/monitor.go @@ -95,7 +95,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error { defer func() { if mon.status.Load() != health.StatusError { - mon.status.Store(health.StatusUnknown) + mon.status.Store(health.StatusUnhealthy) } mon.task.Finish(nil) }()