fix(task): revert to context based approach and fix tasks stuck, improve error handling
Some checks failed
Docker Image CI (nightly) / build-nightly (push) Has been cancelled
Docker Image CI (nightly) / build-nightly-agent (push) Has been cancelled

This commit is contained in:
yusing 2025-05-26 00:32:59 +08:00
parent 2e9f113224
commit 216c03c5ff
11 changed files with 200 additions and 183 deletions

View file

@ -47,7 +47,7 @@ const (
)
func initClientCleaner() {
cleaner := task.RootTask("docker_clients_cleaner", false)
cleaner := task.RootTask("docker_clients_cleaner", true)
go func() {
ticker := time.NewTicker(cleanInterval)
defer ticker.Stop()

View file

@ -155,10 +155,11 @@ func (p *Poller[T, AggregateT]) Start() {
gatherErrsTicker.Stop()
saveTicker.Stop()
if err := p.save(); err != nil {
err := p.save()
if err != nil {
l.Err(err).Msg("failed to save metrics data")
}
t.Finish(nil)
t.Finish(err)
}()
l.Debug().Dur("interval", pollInterval).Msg("Starting poller")

View file

@ -8,11 +8,15 @@ import (
"github.com/rs/zerolog"
)
func HandleError(logger *zerolog.Logger, err error, msg string) {
func convertError(err error) error {
switch {
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return
return nil
default:
logger.Fatal().Err(err).Msg(msg)
return err
}
}
func HandleError(logger *zerolog.Logger, err error, msg string) {
logger.Fatal().Err(err).Msg(msg)
}

View file

@ -104,7 +104,10 @@ func (s *Server) Start(parent task.Parent) {
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
}
Start(subtask, h3, s.acl, &s.l)
if s.http != nil {
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
}
// s.https is not nil (checked above)
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
}
@ -120,7 +123,7 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
setDebugLogger(srv, logger)
proto := proto(srv)
task := parent.Subtask(proto, false)
task := parent.Subtask(proto, true)
var lc net.ListenConfig
var serveFunc func() error
@ -158,9 +161,13 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
})
logStarted(srv, logger)
go func() {
err := serveFunc()
err := convertError(serveFunc())
if err != nil {
HandleError(logger, err, "failed to serve "+proto+" server")
}
task.Finish(err)
}()
return port
}
func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
@ -173,7 +180,7 @@ func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
if err := convertError(srv.Shutdown(ctx)); err != nil {
HandleError(logger, err, "failed to shutdown "+proto+" server")
} else {
logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")

View file

@ -99,6 +99,11 @@ func (r *StreamRoute) HealthMonitor() health.HealthMonitor {
func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed")
go func() {
<-r.task.Context().Done()
r.Close()
}()
for {
select {
case <-r.task.Context().Done():

View file

@ -1,23 +1,27 @@
package task
import (
"fmt"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
)
// debug only.
func (t *Task) listStuckedCallbacks() []string {
callbacks := make([]string, 0, len(t.callbacks))
for c := range t.callbacks {
if !c.done.Load() {
callbacks = append(callbacks, c.about)
}
t.mu.Lock()
defer t.mu.Unlock()
callbacks := make([]string, 0, len(t.callbacksOnFinish))
for c := range t.callbacksOnFinish {
callbacks = append(callbacks, c.about)
}
return callbacks
}
// debug only.
func (t *Task) listStuckedChildren() []string {
t.mu.Lock()
defer t.mu.Unlock()
children := make([]string, 0, len(t.children))
for c := range t.children {
if c.isFinished() {
@ -34,13 +38,15 @@ func (t *Task) listStuckedChildren() []string {
func (t *Task) reportStucked() {
callbacks := t.listStuckedCallbacks()
children := t.listStuckedChildren()
fmtOutput := gperr.Multiline().
Addf("stucked callbacks: %d, stucked children: %d",
len(callbacks), len(children),
).
Addf("callbacks").
AddLinesString(callbacks...).
Addf("children").
AddLinesString(children...)
log.Warn().Msg(fmtOutput.Error())
if len(callbacks) == 0 && len(children) == 0 {
return
}
fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children)))
if len(callbacks) > 0 {
fmtOutput.Add(gperr.New("callbacks").With(gperr.Multiline().AddLinesString(callbacks...)))
}
if len(children) > 0 {
fmtOutput.Add(gperr.New("children").With(gperr.Multiline().AddLinesString(children...)))
}
log.Warn().Msg(fmtOutput.String())
}

View file

@ -4,23 +4,38 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"
_ "unsafe"
"github.com/rs/zerolog/log"
)
var (
taskPool = make(chan *Task, 100)
root = newRoot()
voidTask = &Task{ctx: context.Background()}
root = newRoot()
cancelCtx context.Context
)
func init() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
cancelCtx = ctx //nolint:fatcontext
voidTask.parent = root
}
func testCleanup() {
root = newRoot()
}
func newRoot() *Task {
return newTask("root", nil, true)
return newTask("root", voidTask, true)
}
func noCancel(error) {
// do nothing
}
//go:inline
@ -28,20 +43,31 @@ func newTask(name string, parent *Task, needFinish bool) *Task {
var t *Task
select {
case t = <-taskPool:
t.finished.Store(false)
default:
t = &Task{}
}
t.name = name
t.parent = parent
if needFinish {
t.canceled = make(chan struct{})
t.ctx, t.cancel = context.WithCancelCause(parent.ctx)
} else {
// it will not be nil, because root task always has a canceled channel
t.canceled = parent.canceled
t.ctx, t.cancel = parent.ctx, noCancel
}
return t
}
//go:inline
func (t *Task) needFinish() bool {
return t.ctx != t.parent.ctx
}
//go:inline
func (t *Task) isCanceled() bool {
return t.cancel == nil
}
//go:inline
func putTask(t *Task) {
select {
case taskPool <- t:
@ -50,49 +76,76 @@ func putTask(t *Task) {
}
}
//go:inline
func (t *Task) setCause(cause error) {
if cause == nil {
t.cause = context.Canceled
} else {
t.cause = cause
}
}
//go:inline
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
t.mu.Lock()
if t.cause != nil {
t.mu.Unlock()
if !t.needFinish() {
if waitSubTasks {
waitEmpty(t.children, taskTimeout)
t.parent.addCallback(about, func() {
if !t.waitFinish(taskTimeout) {
t.reportStucked()
}
fn()
}, false)
} else {
t.parent.addCallback(about, fn, false)
}
fn()
return
}
defer t.mu.Unlock()
if t.callbacks == nil {
t.callbacks = make(callbacksSet)
if !waitSubTasks {
t.mu.Lock()
defer t.mu.Unlock()
if t.callbacksOnCancel == nil {
t.callbacksOnCancel = make(callbacksSet)
go func() {
<-t.ctx.Done()
for c := range t.callbacksOnCancel {
go func() {
invokeWithRecover(c)
t.mu.Lock()
delete(t.callbacksOnCancel, c)
t.mu.Unlock()
}()
}
}()
}
t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{}
return
}
t.callbacks[&Callback{
fn: fn,
about: about,
waitChildren: waitSubTasks,
t.mu.Lock()
defer t.mu.Unlock()
if t.isCanceled() {
log.Panic().
Str("task", t.String()).
Str("callback", about).
Msg("callback added to canceled task")
return
}
if t.callbacksOnFinish == nil {
t.callbacksOnFinish = make(callbacksSet)
}
t.callbacksOnFinish[&Callback{
fn: fn,
about: about,
}] = struct{}{}
}
//go:inline
func (t *Task) addChild(child *Task) {
t.mu.Lock()
if t.cause != nil {
t.mu.Unlock()
child.Finish(t.FinishCause())
defer t.mu.Unlock()
if t.isCanceled() {
log.Panic().
Str("task", t.String()).
Str("child", child.Name()).
Msg("child added to canceled task")
return
}
defer t.mu.Unlock()
if t.children == nil {
t.children = make(childrenSet)
}
@ -106,67 +159,19 @@ func (t *Task) removeChild(child *Task) {
delete(t.children, child)
}
func (t *Task) finishChildren() {
t.mu.Lock()
if len(t.children) == 0 {
t.mu.Unlock()
func (t *Task) runOnFinishCallbacks() {
if len(t.callbacksOnFinish) == 0 {
return
}
var wg sync.WaitGroup
for child := range t.children {
wg.Add(1)
for c := range t.callbacksOnFinish {
go func() {
defer wg.Done()
child.Finish(t.cause)
invokeWithRecover(c)
t.mu.Lock()
delete(t.callbacksOnFinish, c)
t.mu.Unlock()
}()
}
clear(t.children)
t.mu.Unlock()
wg.Wait()
}
func (t *Task) runCallbacks() {
t.mu.Lock()
if len(t.callbacks) == 0 {
t.mu.Unlock()
return
}
var wg sync.WaitGroup
var needWait bool
// runs callbacks that does not need wait first
for c := range t.callbacks {
if !c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
} else {
needWait = true
}
}
// runs callbacks that need to wait for children
if needWait {
waitEmpty(t.children, taskTimeout)
for c := range t.callbacks {
if c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
}
}
}
clear(t.callbacks)
t.mu.Unlock()
wg.Wait()
}
func (t *Task) waitFinish(timeout time.Duration) bool {
@ -175,16 +180,24 @@ func (t *Task) waitFinish(timeout time.Duration) bool {
return true
}
if len(t.children) == 0 && len(t.callbacks) == 0 {
return true
t.mu.Lock()
children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish
t.mu.Unlock()
ok := true
if len(children) != 0 {
ok = waitEmpty(children, timeout)
}
ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout)
if !ok {
return false
if len(callbacksOnCancel) != 0 {
ok = ok && waitEmpty(callbacksOnCancel, timeout)
}
t.finished.Store(true)
return true
if len(callbacksOnFinish) != 0 {
ok = ok && waitEmpty(callbacksOnFinish, timeout)
}
return ok
}
//go:inline
@ -193,8 +206,6 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
return true
}
var sema uint32
timer := time.NewTimer(timeout)
defer timer.Stop()
@ -206,7 +217,7 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
case <-timer.C:
return false
default:
runtime_Semacquire(&sema)
time.Sleep(100 * time.Millisecond)
}
}
}
@ -224,6 +235,3 @@ func fmtCause(cause any) error {
return fmt.Errorf("%v", cause)
}
}
//go:linkname runtime_Semacquire sync.runtime_Semacquire
func runtime_Semacquire(s *uint32)

View file

@ -5,7 +5,6 @@ package task
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
@ -27,10 +26,8 @@ type (
Finish(reason any)
}
Callback struct {
fn func()
about string
waitChildren bool
done atomic.Bool
fn func()
about string
}
// Task controls objects' lifetime.
//
@ -40,13 +37,13 @@ type (
Task struct {
name string
parent *Task
children childrenSet
callbacks callbacksSet
parent *Task
children childrenSet
callbacksOnFinish callbacksSet
callbacksOnCancel callbacksSet
cause error
canceled chan struct{}
ctx context.Context
cancel context.CancelCauseFunc
finished atomic.Bool
mu sync.Mutex
@ -68,36 +65,12 @@ type (
const taskTimeout = 3 * time.Second
func (t *Task) Context() context.Context {
return t
}
func (t *Task) Deadline() (time.Time, bool) {
return time.Time{}, false
}
func (t *Task) Done() <-chan struct{} {
return t.canceled
}
func (t *Task) Err() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.cause == nil {
return context.Canceled
}
return t.cause
}
func (t *Task) Value(_ any) any {
return nil
return t.ctx
}
// FinishCause returns the reason / error that caused the task to be finished.
func (t *Task) FinishCause() error {
if t.cause == nil || errors.Is(t.cause, context.Canceled) {
return nil
}
return t.cause
return context.Cause(t.ctx)
}
// OnFinished calls fn when the task is canceled and all subtasks are finished.
@ -118,38 +91,44 @@ func (t *Task) OnCancel(about string, fn func()) {
// then marks the task as finished, with the given reason (if any).
func (t *Task) Finish(reason any) {
t.mu.Lock()
if t.cause != nil {
if t.isCanceled() {
t.mu.Unlock()
return
}
cause := fmtCause(reason)
t.setCause(cause)
// t does not need finish, it shares the canceled channel with its parent
if t == root || t.canceled != t.parent.canceled {
close(t.canceled)
}
t.cancel(fmtCause(reason))
t.ctx, t.cancel = cancelCtx, nil
t.mu.Unlock()
t.finishAndWait()
t.finished.Store(true)
}
func (t *Task) finishAndWait() {
defer putTask(t)
ok := true
t.finishChildren()
t.runCallbacks()
if !waitEmpty(t.children, taskTimeout) {
t.reportStucked()
ok = false
}
t.runOnFinishCallbacks()
if !t.waitFinish(taskTimeout) {
t.reportStucked()
ok = false
}
// clear anyway
clear(t.children)
clear(t.callbacks)
clear(t.callbacksOnFinish)
if t != root {
if t != root && t.needFinish() {
t.parent.removeChild(t)
}
logFinished(t)
if ok {
putTask(t)
}
}
func (t *Task) isFinished() bool {
@ -179,7 +158,7 @@ func (t *Task) Name() string {
// String returns the full name of the task.
func (t *Task) String() string {
if t.parent != nil {
if t.parent != root {
return t.parent.String() + "." + t.name
}
return t.name
@ -192,7 +171,6 @@ func (t *Task) MarshalText() ([]byte, error) {
func invokeWithRecover(cb *Callback) {
defer func() {
cb.done.Store(true)
if err := recover(); err != nil {
log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic")
panicWithDebugStack()

View file

@ -113,7 +113,7 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) {
ExpectTrue(t, finished)
ExpectTrue(t, root.waitFinish(1*time.Second))
ExpectError(t, ErrProgramExiting, context.Cause(task.Context()))
ExpectError(t, context.Canceled, context.Cause(task.Context()))
ExpectError(t, ErrProgramExiting, task.Context().Err())
ExpectError(t, ErrProgramExiting, task.FinishCause())
}

View file

@ -21,11 +21,11 @@ func RootTask(name string, needFinish bool) *Task {
}
func RootContext() context.Context {
return root
return root.Context()
}
func RootContextCanceled() <-chan struct{} {
return root.Done()
return root.Context().Done()
}
func OnProgramExit(about string, fn func()) {
@ -59,10 +59,18 @@ func WaitExit(shutdownTimeout int) {
// still running when the timeout was reached, and their current tree
// of subtasks.
func gracefulShutdown(timeout time.Duration) error {
root.Finish(ErrProgramExiting)
root.finishChildren()
root.runCallbacks()
if !root.waitFinish(timeout) {
root.mu.Lock()
if root.isCanceled() {
cause := context.Cause(root.ctx)
root.mu.Unlock()
return cause
}
root.mu.Unlock()
root.cancel(ErrProgramExiting)
ok := waitEmpty(root.children, timeout)
root.runOnFinishCallbacks()
if !ok || !root.waitFinish(timeout) {
return context.DeadlineExceeded
}
return nil

View file

@ -95,7 +95,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
defer func() {
if mon.status.Load() != health.StatusError {
mon.status.Store(health.StatusUnknown)
mon.status.Store(health.StatusUnhealthy)
}
mon.task.Finish(nil)
}()