fix(task): revert to context based approach and fix tasks stuck, improve error handling
Some checks failed
Docker Image CI (nightly) / build-nightly (push) Has been cancelled
Docker Image CI (nightly) / build-nightly-agent (push) Has been cancelled

This commit is contained in:
yusing 2025-05-26 00:32:59 +08:00
parent 2e9f113224
commit 216c03c5ff
11 changed files with 200 additions and 183 deletions

View file

@ -47,7 +47,7 @@ const (
) )
func initClientCleaner() { func initClientCleaner() {
cleaner := task.RootTask("docker_clients_cleaner", false) cleaner := task.RootTask("docker_clients_cleaner", true)
go func() { go func() {
ticker := time.NewTicker(cleanInterval) ticker := time.NewTicker(cleanInterval)
defer ticker.Stop() defer ticker.Stop()

View file

@ -155,10 +155,11 @@ func (p *Poller[T, AggregateT]) Start() {
gatherErrsTicker.Stop() gatherErrsTicker.Stop()
saveTicker.Stop() saveTicker.Stop()
if err := p.save(); err != nil { err := p.save()
if err != nil {
l.Err(err).Msg("failed to save metrics data") l.Err(err).Msg("failed to save metrics data")
} }
t.Finish(nil) t.Finish(err)
}() }()
l.Debug().Dur("interval", pollInterval).Msg("Starting poller") l.Debug().Dur("interval", pollInterval).Msg("Starting poller")

View file

@ -8,11 +8,15 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
func HandleError(logger *zerolog.Logger, err error, msg string) { func convertError(err error) error {
switch { switch {
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return return nil
default: default:
logger.Fatal().Err(err).Msg(msg) return err
} }
} }
func HandleError(logger *zerolog.Logger, err error, msg string) {
logger.Fatal().Err(err).Msg(msg)
}

View file

@ -104,7 +104,10 @@ func (s *Server) Start(parent task.Parent) {
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig), TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
} }
Start(subtask, h3, s.acl, &s.l) Start(subtask, h3, s.acl, &s.l)
if s.http != nil {
s.http.Handler = advertiseHTTP3(s.http.Handler, h3) s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
}
// s.https is not nil (checked above)
s.https.Handler = advertiseHTTP3(s.https.Handler, h3) s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
} }
@ -120,7 +123,7 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
setDebugLogger(srv, logger) setDebugLogger(srv, logger)
proto := proto(srv) proto := proto(srv)
task := parent.Subtask(proto, false) task := parent.Subtask(proto, true)
var lc net.ListenConfig var lc net.ListenConfig
var serveFunc func() error var serveFunc func() error
@ -158,9 +161,13 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
}) })
logStarted(srv, logger) logStarted(srv, logger)
go func() { go func() {
err := serveFunc() err := convertError(serveFunc())
if err != nil {
HandleError(logger, err, "failed to serve "+proto+" server") HandleError(logger, err, "failed to serve "+proto+" server")
}
task.Finish(err)
}() }()
return port
} }
func stop[Server httpServer](srv Server, logger *zerolog.Logger) { func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
@ -173,7 +180,7 @@ func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second) ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(ctx); err != nil { if err := convertError(srv.Shutdown(ctx)); err != nil {
HandleError(logger, err, "failed to shutdown "+proto+" server") HandleError(logger, err, "failed to shutdown "+proto+" server")
} else { } else {
logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped") logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")

View file

@ -99,6 +99,11 @@ func (r *StreamRoute) HealthMonitor() health.HealthMonitor {
func (r *StreamRoute) acceptConnections() { func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed") defer r.task.Finish("listener closed")
go func() {
<-r.task.Context().Done()
r.Close()
}()
for { for {
select { select {
case <-r.task.Context().Done(): case <-r.task.Context().Done():

View file

@ -1,23 +1,27 @@
package task package task
import ( import (
"fmt"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
) )
// debug only. // debug only.
func (t *Task) listStuckedCallbacks() []string { func (t *Task) listStuckedCallbacks() []string {
callbacks := make([]string, 0, len(t.callbacks)) t.mu.Lock()
for c := range t.callbacks { defer t.mu.Unlock()
if !c.done.Load() { callbacks := make([]string, 0, len(t.callbacksOnFinish))
callbacks = append(callbacks, c.about) for c := range t.callbacksOnFinish {
} callbacks = append(callbacks, c.about)
} }
return callbacks return callbacks
} }
// debug only. // debug only.
func (t *Task) listStuckedChildren() []string { func (t *Task) listStuckedChildren() []string {
t.mu.Lock()
defer t.mu.Unlock()
children := make([]string, 0, len(t.children)) children := make([]string, 0, len(t.children))
for c := range t.children { for c := range t.children {
if c.isFinished() { if c.isFinished() {
@ -34,13 +38,15 @@ func (t *Task) listStuckedChildren() []string {
func (t *Task) reportStucked() { func (t *Task) reportStucked() {
callbacks := t.listStuckedCallbacks() callbacks := t.listStuckedCallbacks()
children := t.listStuckedChildren() children := t.listStuckedChildren()
fmtOutput := gperr.Multiline(). if len(callbacks) == 0 && len(children) == 0 {
Addf("stucked callbacks: %d, stucked children: %d", return
len(callbacks), len(children), }
). fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children)))
Addf("callbacks"). if len(callbacks) > 0 {
AddLinesString(callbacks...). fmtOutput.Add(gperr.New("callbacks").With(gperr.Multiline().AddLinesString(callbacks...)))
Addf("children"). }
AddLinesString(children...) if len(children) > 0 {
log.Warn().Msg(fmtOutput.Error()) fmtOutput.Add(gperr.New("children").With(gperr.Multiline().AddLinesString(children...)))
}
log.Warn().Msg(fmtOutput.String())
} }

View file

@ -4,23 +4,38 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
_ "unsafe"
"github.com/rs/zerolog/log"
) )
var ( var (
taskPool = make(chan *Task, 100) taskPool = make(chan *Task, 100)
root = newRoot() voidTask = &Task{ctx: context.Background()}
root = newRoot()
cancelCtx context.Context
) )
func init() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
cancelCtx = ctx //nolint:fatcontext
voidTask.parent = root
}
func testCleanup() { func testCleanup() {
root = newRoot() root = newRoot()
} }
func newRoot() *Task { func newRoot() *Task {
return newTask("root", nil, true) return newTask("root", voidTask, true)
}
func noCancel(error) {
// do nothing
} }
//go:inline //go:inline
@ -28,20 +43,31 @@ func newTask(name string, parent *Task, needFinish bool) *Task {
var t *Task var t *Task
select { select {
case t = <-taskPool: case t = <-taskPool:
t.finished.Store(false)
default: default:
t = &Task{} t = &Task{}
} }
t.name = name t.name = name
t.parent = parent t.parent = parent
if needFinish { if needFinish {
t.canceled = make(chan struct{}) t.ctx, t.cancel = context.WithCancelCause(parent.ctx)
} else { } else {
// it will not be nil, because root task always has a canceled channel t.ctx, t.cancel = parent.ctx, noCancel
t.canceled = parent.canceled
} }
return t return t
} }
//go:inline
func (t *Task) needFinish() bool {
return t.ctx != t.parent.ctx
}
//go:inline
func (t *Task) isCanceled() bool {
return t.cancel == nil
}
//go:inline
func putTask(t *Task) { func putTask(t *Task) {
select { select {
case taskPool <- t: case taskPool <- t:
@ -50,49 +76,76 @@ func putTask(t *Task) {
} }
} }
//go:inline
func (t *Task) setCause(cause error) {
if cause == nil {
t.cause = context.Canceled
} else {
t.cause = cause
}
}
//go:inline //go:inline
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) { func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
t.mu.Lock() if !t.needFinish() {
if t.cause != nil {
t.mu.Unlock()
if waitSubTasks { if waitSubTasks {
waitEmpty(t.children, taskTimeout) t.parent.addCallback(about, func() {
if !t.waitFinish(taskTimeout) {
t.reportStucked()
}
fn()
}, false)
} else {
t.parent.addCallback(about, fn, false)
} }
fn()
return return
} }
defer t.mu.Unlock() if !waitSubTasks {
if t.callbacks == nil { t.mu.Lock()
t.callbacks = make(callbacksSet) defer t.mu.Unlock()
if t.callbacksOnCancel == nil {
t.callbacksOnCancel = make(callbacksSet)
go func() {
<-t.ctx.Done()
for c := range t.callbacksOnCancel {
go func() {
invokeWithRecover(c)
t.mu.Lock()
delete(t.callbacksOnCancel, c)
t.mu.Unlock()
}()
}
}()
}
t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{}
return
} }
t.callbacks[&Callback{
fn: fn, t.mu.Lock()
about: about, defer t.mu.Unlock()
waitChildren: waitSubTasks,
if t.isCanceled() {
log.Panic().
Str("task", t.String()).
Str("callback", about).
Msg("callback added to canceled task")
return
}
if t.callbacksOnFinish == nil {
t.callbacksOnFinish = make(callbacksSet)
}
t.callbacksOnFinish[&Callback{
fn: fn,
about: about,
}] = struct{}{} }] = struct{}{}
} }
//go:inline //go:inline
func (t *Task) addChild(child *Task) { func (t *Task) addChild(child *Task) {
t.mu.Lock() t.mu.Lock()
if t.cause != nil { defer t.mu.Unlock()
t.mu.Unlock()
child.Finish(t.FinishCause()) if t.isCanceled() {
log.Panic().
Str("task", t.String()).
Str("child", child.Name()).
Msg("child added to canceled task")
return return
} }
defer t.mu.Unlock()
if t.children == nil { if t.children == nil {
t.children = make(childrenSet) t.children = make(childrenSet)
} }
@ -106,67 +159,19 @@ func (t *Task) removeChild(child *Task) {
delete(t.children, child) delete(t.children, child)
} }
func (t *Task) finishChildren() { func (t *Task) runOnFinishCallbacks() {
t.mu.Lock() if len(t.callbacksOnFinish) == 0 {
if len(t.children) == 0 {
t.mu.Unlock()
return return
} }
var wg sync.WaitGroup for c := range t.callbacksOnFinish {
for child := range t.children {
wg.Add(1)
go func() { go func() {
defer wg.Done() invokeWithRecover(c)
child.Finish(t.cause) t.mu.Lock()
delete(t.callbacksOnFinish, c)
t.mu.Unlock()
}() }()
} }
clear(t.children)
t.mu.Unlock()
wg.Wait()
}
func (t *Task) runCallbacks() {
t.mu.Lock()
if len(t.callbacks) == 0 {
t.mu.Unlock()
return
}
var wg sync.WaitGroup
var needWait bool
// runs callbacks that does not need wait first
for c := range t.callbacks {
if !c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
} else {
needWait = true
}
}
// runs callbacks that need to wait for children
if needWait {
waitEmpty(t.children, taskTimeout)
for c := range t.callbacks {
if c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
}
}
}
clear(t.callbacks)
t.mu.Unlock()
wg.Wait()
} }
func (t *Task) waitFinish(timeout time.Duration) bool { func (t *Task) waitFinish(timeout time.Duration) bool {
@ -175,16 +180,24 @@ func (t *Task) waitFinish(timeout time.Duration) bool {
return true return true
} }
if len(t.children) == 0 && len(t.callbacks) == 0 { t.mu.Lock()
return true children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish
t.mu.Unlock()
ok := true
if len(children) != 0 {
ok = waitEmpty(children, timeout)
} }
ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout) if len(callbacksOnCancel) != 0 {
if !ok { ok = ok && waitEmpty(callbacksOnCancel, timeout)
return false
} }
t.finished.Store(true)
return true if len(callbacksOnFinish) != 0 {
ok = ok && waitEmpty(callbacksOnFinish, timeout)
}
return ok
} }
//go:inline //go:inline
@ -193,8 +206,6 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
return true return true
} }
var sema uint32
timer := time.NewTimer(timeout) timer := time.NewTimer(timeout)
defer timer.Stop() defer timer.Stop()
@ -206,7 +217,7 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
case <-timer.C: case <-timer.C:
return false return false
default: default:
runtime_Semacquire(&sema) time.Sleep(100 * time.Millisecond)
} }
} }
} }
@ -224,6 +235,3 @@ func fmtCause(cause any) error {
return fmt.Errorf("%v", cause) return fmt.Errorf("%v", cause)
} }
} }
//go:linkname runtime_Semacquire sync.runtime_Semacquire
func runtime_Semacquire(s *uint32)

View file

@ -5,7 +5,6 @@ package task
import ( import (
"context" "context"
"errors"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -27,10 +26,8 @@ type (
Finish(reason any) Finish(reason any)
} }
Callback struct { Callback struct {
fn func() fn func()
about string about string
waitChildren bool
done atomic.Bool
} }
// Task controls objects' lifetime. // Task controls objects' lifetime.
// //
@ -40,13 +37,13 @@ type (
Task struct { Task struct {
name string name string
parent *Task parent *Task
children childrenSet children childrenSet
callbacks callbacksSet callbacksOnFinish callbacksSet
callbacksOnCancel callbacksSet
cause error ctx context.Context
cancel context.CancelCauseFunc
canceled chan struct{}
finished atomic.Bool finished atomic.Bool
mu sync.Mutex mu sync.Mutex
@ -68,36 +65,12 @@ type (
const taskTimeout = 3 * time.Second const taskTimeout = 3 * time.Second
func (t *Task) Context() context.Context { func (t *Task) Context() context.Context {
return t return t.ctx
}
func (t *Task) Deadline() (time.Time, bool) {
return time.Time{}, false
}
func (t *Task) Done() <-chan struct{} {
return t.canceled
}
func (t *Task) Err() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.cause == nil {
return context.Canceled
}
return t.cause
}
func (t *Task) Value(_ any) any {
return nil
} }
// FinishCause returns the reason / error that caused the task to be finished. // FinishCause returns the reason / error that caused the task to be finished.
func (t *Task) FinishCause() error { func (t *Task) FinishCause() error {
if t.cause == nil || errors.Is(t.cause, context.Canceled) { return context.Cause(t.ctx)
return nil
}
return t.cause
} }
// OnFinished calls fn when the task is canceled and all subtasks are finished. // OnFinished calls fn when the task is canceled and all subtasks are finished.
@ -118,38 +91,44 @@ func (t *Task) OnCancel(about string, fn func()) {
// then marks the task as finished, with the given reason (if any). // then marks the task as finished, with the given reason (if any).
func (t *Task) Finish(reason any) { func (t *Task) Finish(reason any) {
t.mu.Lock() t.mu.Lock()
if t.cause != nil { if t.isCanceled() {
t.mu.Unlock() t.mu.Unlock()
return return
} }
cause := fmtCause(reason) t.cancel(fmtCause(reason))
t.setCause(cause) t.ctx, t.cancel = cancelCtx, nil
// t does not need finish, it shares the canceled channel with its parent
if t == root || t.canceled != t.parent.canceled {
close(t.canceled)
}
t.mu.Unlock() t.mu.Unlock()
t.finishAndWait() t.finishAndWait()
t.finished.Store(true)
} }
func (t *Task) finishAndWait() { func (t *Task) finishAndWait() {
defer putTask(t) ok := true
t.finishChildren() if !waitEmpty(t.children, taskTimeout) {
t.runCallbacks() t.reportStucked()
ok = false
}
t.runOnFinishCallbacks()
if !t.waitFinish(taskTimeout) { if !t.waitFinish(taskTimeout) {
t.reportStucked() t.reportStucked()
ok = false
} }
// clear anyway // clear anyway
clear(t.children) clear(t.children)
clear(t.callbacks) clear(t.callbacksOnFinish)
if t != root { if t != root && t.needFinish() {
t.parent.removeChild(t) t.parent.removeChild(t)
} }
logFinished(t) logFinished(t)
if ok {
putTask(t)
}
} }
func (t *Task) isFinished() bool { func (t *Task) isFinished() bool {
@ -179,7 +158,7 @@ func (t *Task) Name() string {
// String returns the full name of the task. // String returns the full name of the task.
func (t *Task) String() string { func (t *Task) String() string {
if t.parent != nil { if t.parent != root {
return t.parent.String() + "." + t.name return t.parent.String() + "." + t.name
} }
return t.name return t.name
@ -192,7 +171,6 @@ func (t *Task) MarshalText() ([]byte, error) {
func invokeWithRecover(cb *Callback) { func invokeWithRecover(cb *Callback) {
defer func() { defer func() {
cb.done.Store(true)
if err := recover(); err != nil { if err := recover(); err != nil {
log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic") log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic")
panicWithDebugStack() panicWithDebugStack()

View file

@ -113,7 +113,7 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) {
ExpectTrue(t, finished) ExpectTrue(t, finished)
ExpectTrue(t, root.waitFinish(1*time.Second)) ExpectTrue(t, root.waitFinish(1*time.Second))
ExpectError(t, ErrProgramExiting, context.Cause(task.Context())) ExpectError(t, context.Canceled, context.Cause(task.Context()))
ExpectError(t, ErrProgramExiting, task.Context().Err()) ExpectError(t, ErrProgramExiting, task.Context().Err())
ExpectError(t, ErrProgramExiting, task.FinishCause()) ExpectError(t, ErrProgramExiting, task.FinishCause())
} }

View file

@ -21,11 +21,11 @@ func RootTask(name string, needFinish bool) *Task {
} }
func RootContext() context.Context { func RootContext() context.Context {
return root return root.Context()
} }
func RootContextCanceled() <-chan struct{} { func RootContextCanceled() <-chan struct{} {
return root.Done() return root.Context().Done()
} }
func OnProgramExit(about string, fn func()) { func OnProgramExit(about string, fn func()) {
@ -59,10 +59,18 @@ func WaitExit(shutdownTimeout int) {
// still running when the timeout was reached, and their current tree // still running when the timeout was reached, and their current tree
// of subtasks. // of subtasks.
func gracefulShutdown(timeout time.Duration) error { func gracefulShutdown(timeout time.Duration) error {
root.Finish(ErrProgramExiting) root.mu.Lock()
root.finishChildren() if root.isCanceled() {
root.runCallbacks() cause := context.Cause(root.ctx)
if !root.waitFinish(timeout) { root.mu.Unlock()
return cause
}
root.mu.Unlock()
root.cancel(ErrProgramExiting)
ok := waitEmpty(root.children, timeout)
root.runOnFinishCallbacks()
if !ok || !root.waitFinish(timeout) {
return context.DeadlineExceeded return context.DeadlineExceeded
} }
return nil return nil

View file

@ -95,7 +95,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
defer func() { defer func() {
if mon.status.Load() != health.StatusError { if mon.status.Load() != health.StatusError {
mon.status.Store(health.StatusUnknown) mon.status.Store(health.StatusUnhealthy)
} }
mon.task.Finish(nil) mon.task.Finish(nil)
}() }()