diff --git a/internal/task/debug.go b/internal/task/debug.go new file mode 100644 index 0000000..7e85b74 --- /dev/null +++ b/internal/task/debug.go @@ -0,0 +1,26 @@ +package task + +import "strings" + +// debug only. +func (t *Task) listChildren() []string { + var children []string + allTasks.Range(func(child *Task) bool { + if child.parent == t { + children = append(children, strings.TrimPrefix(child.name, t.name+".")) + } + return true + }) + return children +} + +// debug only. +func (t *Task) listCallbacks() []string { + var callbacks []string + t.mu.Lock() + defer t.mu.Unlock() + for c := range t.callbacks { + callbacks = append(callbacks, c.about) + } + return callbacks +} diff --git a/internal/task/impl.go b/internal/task/impl.go new file mode 100644 index 0000000..3c8bd79 --- /dev/null +++ b/internal/task/impl.go @@ -0,0 +1,75 @@ +package task + +import ( + "errors" + "fmt" + "sync/atomic" + "time" +) + +func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.callbacks == nil { + t.callbacks = make(map[*Callback]struct{}) + t.callbacksDone = make(chan struct{}) + } + t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{} +} + +func (t *Task) addChildCount() { + if atomic.AddUint32(&t.children, 1) == 1 { + t.mu.Lock() + if t.childrenDone == nil { + t.childrenDone = make(chan struct{}) + } + t.mu.Unlock() + } +} + +func (t *Task) subChildCount() { + if atomic.AddUint32(&t.children, ^uint32(0)) == 0 { + close(t.childrenDone) + } +} + +func (t *Task) runCallbacks() { + t.mu.Lock() + defer t.mu.Unlock() + if len(t.callbacks) == 0 { + return + } + for c := range t.callbacks { + if c.waitChildren { + waitWithTimeout(t.childrenDone) + } + t.invokeWithRecover(c.fn, c.about) + delete(t.callbacks, c) + } + close(t.callbacksDone) +} + +func waitWithTimeout(ch <-chan struct{}) bool { + if ch == nil { + return true + } + select { + case <-ch: + return true + case <-time.After(taskTimeout): + return false + } +} + +func fmtCause(cause any) error { + switch cause := cause.(type) { + case nil: + return nil + case error: + return cause + case string: + return errors.New(cause) + default: + return fmt.Errorf("%v", cause) + } +} diff --git a/internal/task/task.go b/internal/task/task.go index a8b6cde..7b46bac 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -2,10 +2,7 @@ package task import ( "context" - "errors" - "fmt" "runtime/debug" - "strings" "sync" "time" @@ -26,6 +23,11 @@ type ( TaskFinisher interface { Finish(reason any) } + Callback struct { + fn func() + about string + waitChildren bool + } // Task controls objects' lifetime. // // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. @@ -34,15 +36,19 @@ type ( Task struct { name string - children sync.WaitGroup + parent *Task + children uint32 + childrenDone chan struct{} - onFinished sync.WaitGroup - finished chan struct{} + callbacks map[*Callback]struct{} + callbacksDone chan struct{} + + finished chan struct{} + finishedCalled bool + mu sync.Mutex ctx context.Context cancel context.CancelCauseFunc - - once sync.Once } Parent interface { Context() context.Context @@ -72,96 +78,55 @@ func (t *Task) FinishCause() error { // // It should not be called after Finish is called. func (t *Task) OnFinished(about string, fn func()) { - t.onCancel(about, fn, true) + t.addCallback(about, fn, true) } // OnCancel calls fn when the task is canceled. // // It should not be called after Finish is called. func (t *Task) OnCancel(about string, fn func()) { - t.onCancel(about, fn, false) -} - -func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) { - t.onFinished.Add(1) - go func() { - <-t.ctx.Done() - if waitSubTasks { - waitWithTimeout(&t.children) - } - t.invokeWithRecover(fn, about) - t.onFinished.Done() - }() + t.addCallback(about, fn, false) } // Finish cancel all subtasks and wait for them to finish, // then marks the task as finished, with the given reason (if any). func (t *Task) Finish(reason any) { - t.once.Do(func() { - t.finish(reason) - }) + t.mu.Lock() + if t.finishedCalled { + t.mu.Unlock() + return + } + t.finishedCalled = true + t.mu.Unlock() + t.finish(reason) } func (t *Task) finish(reason any) { t.cancel(fmtCause(reason)) - if !waitWithTimeout(&t.children) { - logger.Debug(). - Strs("subtasks", t.listChildren()). - Msg("Timeout waiting for these subtasks to finish") - } - if !waitWithTimeout(&t.onFinished) { + if !waitWithTimeout(t.childrenDone) { logger.Debug(). Str("task", t.name). + Strs("subtasks", t.listChildren()). + Msg("Timeout waiting for subtasks to finish") + } + go t.runCallbacks() + if !waitWithTimeout(t.callbacksDone) { + logger.Debug(). + Str("task", t.name). + Strs("callbacks", t.listCallbacks()). Msg("Timeout waiting for callbacks to finish") } if t.finished != nil { close(t.finished) } + if t == root { + return + } + t.parent.subChildCount() + allTasks.Remove(t) logger.Trace().Msg("task " + t.name + " finished") } -// debug only. -func (t *Task) listChildren() []string { - var children []string - allTasks.Range(func(child *Task) bool { - if strings.HasPrefix(child.name, t.name+".") { - children = append(children, child.name) - } - return true - }) - return children -} - -func waitWithTimeout(wg *sync.WaitGroup) bool { - done := make(chan struct{}) - timeout := time.After(taskTimeout) - - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - return true - case <-timeout: - return false - } -} - -func fmtCause(cause any) error { - switch cause := cause.(type) { - case nil: - return nil - case error: - return cause - case string: - return errors.New(cause) - default: - return fmt.Errorf("%v", cause) - } -} - // Subtask returns a new subtask with the given name, derived from the parent's context. // // This should not be called after Finish is called. @@ -170,19 +135,19 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task { ctx, cancel := context.WithCancelCause(t.ctx) child := &Task{ - finished: make(chan struct{}, 1), + parent: t, + finished: make(chan struct{}), ctx: ctx, cancel: cancel, } if t != root { child.name = t.name + "." + name - allTasks.Add(child) } else { child.name = name } - allTasksWg.Add(1) - t.children.Add(1) + allTasks.Add(child) + t.addChildCount() if !nf { go func() { @@ -191,13 +156,6 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task { }() } - go func() { - <-child.finished - allTasksWg.Done() - t.children.Done() - allTasks.Remove(child) - }() - logger.Trace().Msg("task " + child.name + " started") return child } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index c5b866e..64b4d56 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -10,7 +10,7 @@ import ( ) func testTask() *Task { - return RootTask("test", false) + return RootTask("test", true) } func TestChildTaskCancellation(t *testing.T) { @@ -34,7 +34,7 @@ func TestChildTaskCancellation(t *testing.T) { parent.cancel(nil) // should also cancel child select { - case <-child.Finished(): + case <-child.Context().Done(): ExpectError(t, context.Canceled, child.Context().Err()) default: t.Fatal("subTask context was not canceled as expected") diff --git a/internal/task/utils.go b/internal/task/utils.go index 2fa0545..d6251e6 100644 --- a/internal/task/utils.go +++ b/internal/task/utils.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "slices" - "sync" "time" "github.com/yusing/go-proxy/internal/logging" @@ -17,25 +16,24 @@ var ErrProgramExiting = errors.New("program exiting") var logger = logging.With().Str("module", "task").Logger() var ( - root = newRoot() - allTasks = F.NewSet[*Task]() - allTasksWg sync.WaitGroup + root = newRoot() + allTasks = F.NewSet[*Task]() ) func testCleanup() { root = newRoot() allTasks.Clear() - allTasksWg = sync.WaitGroup{} } // RootTask returns a new Task with the given name, derived from the root context. -func RootTask(name string, needFinish bool) *Task { - return root.Subtask(name, needFinish) +func RootTask(name string, needFinish ...bool) *Task { + return root.Subtask(name, needFinish...) } func newRoot() *Task { t := &Task{name: "root"} t.ctx, t.cancel = context.WithCancelCause(context.Background()) + t.callbacks = make(map[*Callback]struct{}) return t } @@ -57,19 +55,12 @@ func OnProgramExit(about string, fn func()) { // still running when the timeout was reached, and their current tree // of subtasks. func GracefulShutdown(timeout time.Duration) (err error) { - root.cancel(ErrProgramExiting) + go root.Finish(ErrProgramExiting) - done := make(chan struct{}) after := time.After(timeout) - - go func() { - allTasksWg.Wait() - close(done) - }() - for { select { - case <-done: + case <-root.finished: return case <-after: b, err := json.Marshal(DebugTaskList())