task package: replace waitgroup with channel, fix stuck

This commit is contained in:
yusing 2025-01-02 11:12:13 +08:00
parent af14966b09
commit 2fe0b888bd
5 changed files with 152 additions and 102 deletions

26
internal/task/debug.go Normal file
View file

@ -0,0 +1,26 @@
package task
import "strings"
// debug only.
func (t *Task) listChildren() []string {
var children []string
allTasks.Range(func(child *Task) bool {
if child.parent == t {
children = append(children, strings.TrimPrefix(child.name, t.name+"."))
}
return true
})
return children
}
// debug only.
func (t *Task) listCallbacks() []string {
var callbacks []string
t.mu.Lock()
defer t.mu.Unlock()
for c := range t.callbacks {
callbacks = append(callbacks, c.about)
}
return callbacks
}

75
internal/task/impl.go Normal file
View file

@ -0,0 +1,75 @@
package task
import (
"errors"
"fmt"
"sync/atomic"
"time"
)
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
t.mu.Lock()
defer t.mu.Unlock()
if t.callbacks == nil {
t.callbacks = make(map[*Callback]struct{})
t.callbacksDone = make(chan struct{})
}
t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{}
}
func (t *Task) addChildCount() {
if atomic.AddUint32(&t.children, 1) == 1 {
t.mu.Lock()
if t.childrenDone == nil {
t.childrenDone = make(chan struct{})
}
t.mu.Unlock()
}
}
func (t *Task) subChildCount() {
if atomic.AddUint32(&t.children, ^uint32(0)) == 0 {
close(t.childrenDone)
}
}
func (t *Task) runCallbacks() {
t.mu.Lock()
defer t.mu.Unlock()
if len(t.callbacks) == 0 {
return
}
for c := range t.callbacks {
if c.waitChildren {
waitWithTimeout(t.childrenDone)
}
t.invokeWithRecover(c.fn, c.about)
delete(t.callbacks, c)
}
close(t.callbacksDone)
}
func waitWithTimeout(ch <-chan struct{}) bool {
if ch == nil {
return true
}
select {
case <-ch:
return true
case <-time.After(taskTimeout):
return false
}
}
func fmtCause(cause any) error {
switch cause := cause.(type) {
case nil:
return nil
case error:
return cause
case string:
return errors.New(cause)
default:
return fmt.Errorf("%v", cause)
}
}

View file

@ -2,10 +2,7 @@ package task
import ( import (
"context" "context"
"errors"
"fmt"
"runtime/debug" "runtime/debug"
"strings"
"sync" "sync"
"time" "time"
@ -26,6 +23,11 @@ type (
TaskFinisher interface { TaskFinisher interface {
Finish(reason any) Finish(reason any)
} }
Callback struct {
fn func()
about string
waitChildren bool
}
// Task controls objects' lifetime. // Task controls objects' lifetime.
// //
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
@ -34,15 +36,19 @@ type (
Task struct { Task struct {
name string name string
children sync.WaitGroup parent *Task
children uint32
childrenDone chan struct{}
callbacks map[*Callback]struct{}
callbacksDone chan struct{}
onFinished sync.WaitGroup
finished chan struct{} finished chan struct{}
finishedCalled bool
mu sync.Mutex
ctx context.Context ctx context.Context
cancel context.CancelCauseFunc cancel context.CancelCauseFunc
once sync.Once
} }
Parent interface { Parent interface {
Context() context.Context Context() context.Context
@ -72,96 +78,55 @@ func (t *Task) FinishCause() error {
// //
// It should not be called after Finish is called. // It should not be called after Finish is called.
func (t *Task) OnFinished(about string, fn func()) { func (t *Task) OnFinished(about string, fn func()) {
t.onCancel(about, fn, true) t.addCallback(about, fn, true)
} }
// OnCancel calls fn when the task is canceled. // OnCancel calls fn when the task is canceled.
// //
// It should not be called after Finish is called. // It should not be called after Finish is called.
func (t *Task) OnCancel(about string, fn func()) { func (t *Task) OnCancel(about string, fn func()) {
t.onCancel(about, fn, false) t.addCallback(about, fn, false)
}
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
t.onFinished.Add(1)
go func() {
<-t.ctx.Done()
if waitSubTasks {
waitWithTimeout(&t.children)
}
t.invokeWithRecover(fn, about)
t.onFinished.Done()
}()
} }
// Finish cancel all subtasks and wait for them to finish, // Finish cancel all subtasks and wait for them to finish,
// then marks the task as finished, with the given reason (if any). // then marks the task as finished, with the given reason (if any).
func (t *Task) Finish(reason any) { func (t *Task) Finish(reason any) {
t.once.Do(func() { t.mu.Lock()
if t.finishedCalled {
t.mu.Unlock()
return
}
t.finishedCalled = true
t.mu.Unlock()
t.finish(reason) t.finish(reason)
})
} }
func (t *Task) finish(reason any) { func (t *Task) finish(reason any) {
t.cancel(fmtCause(reason)) t.cancel(fmtCause(reason))
if !waitWithTimeout(&t.children) { if !waitWithTimeout(t.childrenDone) {
logger.Debug().
Strs("subtasks", t.listChildren()).
Msg("Timeout waiting for these subtasks to finish")
}
if !waitWithTimeout(&t.onFinished) {
logger.Debug(). logger.Debug().
Str("task", t.name). Str("task", t.name).
Strs("subtasks", t.listChildren()).
Msg("Timeout waiting for subtasks to finish")
}
go t.runCallbacks()
if !waitWithTimeout(t.callbacksDone) {
logger.Debug().
Str("task", t.name).
Strs("callbacks", t.listCallbacks()).
Msg("Timeout waiting for callbacks to finish") Msg("Timeout waiting for callbacks to finish")
} }
if t.finished != nil { if t.finished != nil {
close(t.finished) close(t.finished)
} }
if t == root {
return
}
t.parent.subChildCount()
allTasks.Remove(t)
logger.Trace().Msg("task " + t.name + " finished") logger.Trace().Msg("task " + t.name + " finished")
} }
// debug only.
func (t *Task) listChildren() []string {
var children []string
allTasks.Range(func(child *Task) bool {
if strings.HasPrefix(child.name, t.name+".") {
children = append(children, child.name)
}
return true
})
return children
}
func waitWithTimeout(wg *sync.WaitGroup) bool {
done := make(chan struct{})
timeout := time.After(taskTimeout)
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return true
case <-timeout:
return false
}
}
func fmtCause(cause any) error {
switch cause := cause.(type) {
case nil:
return nil
case error:
return cause
case string:
return errors.New(cause)
default:
return fmt.Errorf("%v", cause)
}
}
// Subtask returns a new subtask with the given name, derived from the parent's context. // Subtask returns a new subtask with the given name, derived from the parent's context.
// //
// This should not be called after Finish is called. // This should not be called after Finish is called.
@ -170,19 +135,19 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
ctx, cancel := context.WithCancelCause(t.ctx) ctx, cancel := context.WithCancelCause(t.ctx)
child := &Task{ child := &Task{
finished: make(chan struct{}, 1), parent: t,
finished: make(chan struct{}),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
if t != root { if t != root {
child.name = t.name + "." + name child.name = t.name + "." + name
allTasks.Add(child)
} else { } else {
child.name = name child.name = name
} }
allTasksWg.Add(1) allTasks.Add(child)
t.children.Add(1) t.addChildCount()
if !nf { if !nf {
go func() { go func() {
@ -191,13 +156,6 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
}() }()
} }
go func() {
<-child.finished
allTasksWg.Done()
t.children.Done()
allTasks.Remove(child)
}()
logger.Trace().Msg("task " + child.name + " started") logger.Trace().Msg("task " + child.name + " started")
return child return child
} }

View file

@ -10,7 +10,7 @@ import (
) )
func testTask() *Task { func testTask() *Task {
return RootTask("test", false) return RootTask("test", true)
} }
func TestChildTaskCancellation(t *testing.T) { func TestChildTaskCancellation(t *testing.T) {
@ -34,7 +34,7 @@ func TestChildTaskCancellation(t *testing.T) {
parent.cancel(nil) // should also cancel child parent.cancel(nil) // should also cancel child
select { select {
case <-child.Finished(): case <-child.Context().Done():
ExpectError(t, context.Canceled, child.Context().Err()) ExpectError(t, context.Canceled, child.Context().Err())
default: default:
t.Fatal("subTask context was not canceled as expected") t.Fatal("subTask context was not canceled as expected")

View file

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"slices" "slices"
"sync"
"time" "time"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
@ -19,23 +18,22 @@ var logger = logging.With().Str("module", "task").Logger()
var ( var (
root = newRoot() root = newRoot()
allTasks = F.NewSet[*Task]() allTasks = F.NewSet[*Task]()
allTasksWg sync.WaitGroup
) )
func testCleanup() { func testCleanup() {
root = newRoot() root = newRoot()
allTasks.Clear() allTasks.Clear()
allTasksWg = sync.WaitGroup{}
} }
// RootTask returns a new Task with the given name, derived from the root context. // RootTask returns a new Task with the given name, derived from the root context.
func RootTask(name string, needFinish bool) *Task { func RootTask(name string, needFinish ...bool) *Task {
return root.Subtask(name, needFinish) return root.Subtask(name, needFinish...)
} }
func newRoot() *Task { func newRoot() *Task {
t := &Task{name: "root"} t := &Task{name: "root"}
t.ctx, t.cancel = context.WithCancelCause(context.Background()) t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.callbacks = make(map[*Callback]struct{})
return t return t
} }
@ -57,19 +55,12 @@ func OnProgramExit(about string, fn func()) {
// still running when the timeout was reached, and their current tree // still running when the timeout was reached, and their current tree
// of subtasks. // of subtasks.
func GracefulShutdown(timeout time.Duration) (err error) { func GracefulShutdown(timeout time.Duration) (err error) {
root.cancel(ErrProgramExiting) go root.Finish(ErrProgramExiting)
done := make(chan struct{})
after := time.After(timeout) after := time.After(timeout)
go func() {
allTasksWg.Wait()
close(done)
}()
for { for {
select { select {
case <-done: case <-root.finished:
return return
case <-after: case <-after:
b, err := json.Marshal(DebugTaskList()) b, err := json.Marshal(DebugTaskList())