mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-01 21:14:24 +02:00
task package: replace waitgroup with channel, fix stuck
This commit is contained in:
parent
af14966b09
commit
2fe0b888bd
5 changed files with 152 additions and 102 deletions
26
internal/task/debug.go
Normal file
26
internal/task/debug.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package task
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// debug only.
|
||||||
|
func (t *Task) listChildren() []string {
|
||||||
|
var children []string
|
||||||
|
allTasks.Range(func(child *Task) bool {
|
||||||
|
if child.parent == t {
|
||||||
|
children = append(children, strings.TrimPrefix(child.name, t.name+"."))
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return children
|
||||||
|
}
|
||||||
|
|
||||||
|
// debug only.
|
||||||
|
func (t *Task) listCallbacks() []string {
|
||||||
|
var callbacks []string
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
for c := range t.callbacks {
|
||||||
|
callbacks = append(callbacks, c.about)
|
||||||
|
}
|
||||||
|
return callbacks
|
||||||
|
}
|
75
internal/task/impl.go
Normal file
75
internal/task/impl.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.callbacks == nil {
|
||||||
|
t.callbacks = make(map[*Callback]struct{})
|
||||||
|
t.callbacksDone = make(chan struct{})
|
||||||
|
}
|
||||||
|
t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) addChildCount() {
|
||||||
|
if atomic.AddUint32(&t.children, 1) == 1 {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.childrenDone == nil {
|
||||||
|
t.childrenDone = make(chan struct{})
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) subChildCount() {
|
||||||
|
if atomic.AddUint32(&t.children, ^uint32(0)) == 0 {
|
||||||
|
close(t.childrenDone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) runCallbacks() {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if len(t.callbacks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for c := range t.callbacks {
|
||||||
|
if c.waitChildren {
|
||||||
|
waitWithTimeout(t.childrenDone)
|
||||||
|
}
|
||||||
|
t.invokeWithRecover(c.fn, c.about)
|
||||||
|
delete(t.callbacks, c)
|
||||||
|
}
|
||||||
|
close(t.callbacksDone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitWithTimeout(ch <-chan struct{}) bool {
|
||||||
|
if ch == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
return true
|
||||||
|
case <-time.After(taskTimeout):
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fmtCause(cause any) error {
|
||||||
|
switch cause := cause.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case error:
|
||||||
|
return cause
|
||||||
|
case string:
|
||||||
|
return errors.New(cause)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v", cause)
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,10 +2,7 @@ package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -26,6 +23,11 @@ type (
|
||||||
TaskFinisher interface {
|
TaskFinisher interface {
|
||||||
Finish(reason any)
|
Finish(reason any)
|
||||||
}
|
}
|
||||||
|
Callback struct {
|
||||||
|
fn func()
|
||||||
|
about string
|
||||||
|
waitChildren bool
|
||||||
|
}
|
||||||
// Task controls objects' lifetime.
|
// Task controls objects' lifetime.
|
||||||
//
|
//
|
||||||
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
|
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
|
||||||
|
@ -34,15 +36,19 @@ type (
|
||||||
Task struct {
|
Task struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
children sync.WaitGroup
|
parent *Task
|
||||||
|
children uint32
|
||||||
|
childrenDone chan struct{}
|
||||||
|
|
||||||
|
callbacks map[*Callback]struct{}
|
||||||
|
callbacksDone chan struct{}
|
||||||
|
|
||||||
onFinished sync.WaitGroup
|
|
||||||
finished chan struct{}
|
finished chan struct{}
|
||||||
|
finishedCalled bool
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelCauseFunc
|
cancel context.CancelCauseFunc
|
||||||
|
|
||||||
once sync.Once
|
|
||||||
}
|
}
|
||||||
Parent interface {
|
Parent interface {
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
|
@ -72,96 +78,55 @@ func (t *Task) FinishCause() error {
|
||||||
//
|
//
|
||||||
// It should not be called after Finish is called.
|
// It should not be called after Finish is called.
|
||||||
func (t *Task) OnFinished(about string, fn func()) {
|
func (t *Task) OnFinished(about string, fn func()) {
|
||||||
t.onCancel(about, fn, true)
|
t.addCallback(about, fn, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnCancel calls fn when the task is canceled.
|
// OnCancel calls fn when the task is canceled.
|
||||||
//
|
//
|
||||||
// It should not be called after Finish is called.
|
// It should not be called after Finish is called.
|
||||||
func (t *Task) OnCancel(about string, fn func()) {
|
func (t *Task) OnCancel(about string, fn func()) {
|
||||||
t.onCancel(about, fn, false)
|
t.addCallback(about, fn, false)
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
|
|
||||||
t.onFinished.Add(1)
|
|
||||||
go func() {
|
|
||||||
<-t.ctx.Done()
|
|
||||||
if waitSubTasks {
|
|
||||||
waitWithTimeout(&t.children)
|
|
||||||
}
|
|
||||||
t.invokeWithRecover(fn, about)
|
|
||||||
t.onFinished.Done()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finish cancel all subtasks and wait for them to finish,
|
// Finish cancel all subtasks and wait for them to finish,
|
||||||
// then marks the task as finished, with the given reason (if any).
|
// then marks the task as finished, with the given reason (if any).
|
||||||
func (t *Task) Finish(reason any) {
|
func (t *Task) Finish(reason any) {
|
||||||
t.once.Do(func() {
|
t.mu.Lock()
|
||||||
|
if t.finishedCalled {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.finishedCalled = true
|
||||||
|
t.mu.Unlock()
|
||||||
t.finish(reason)
|
t.finish(reason)
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Task) finish(reason any) {
|
func (t *Task) finish(reason any) {
|
||||||
t.cancel(fmtCause(reason))
|
t.cancel(fmtCause(reason))
|
||||||
if !waitWithTimeout(&t.children) {
|
if !waitWithTimeout(t.childrenDone) {
|
||||||
logger.Debug().
|
|
||||||
Strs("subtasks", t.listChildren()).
|
|
||||||
Msg("Timeout waiting for these subtasks to finish")
|
|
||||||
}
|
|
||||||
if !waitWithTimeout(&t.onFinished) {
|
|
||||||
logger.Debug().
|
logger.Debug().
|
||||||
Str("task", t.name).
|
Str("task", t.name).
|
||||||
|
Strs("subtasks", t.listChildren()).
|
||||||
|
Msg("Timeout waiting for subtasks to finish")
|
||||||
|
}
|
||||||
|
go t.runCallbacks()
|
||||||
|
if !waitWithTimeout(t.callbacksDone) {
|
||||||
|
logger.Debug().
|
||||||
|
Str("task", t.name).
|
||||||
|
Strs("callbacks", t.listCallbacks()).
|
||||||
Msg("Timeout waiting for callbacks to finish")
|
Msg("Timeout waiting for callbacks to finish")
|
||||||
}
|
}
|
||||||
if t.finished != nil {
|
if t.finished != nil {
|
||||||
close(t.finished)
|
close(t.finished)
|
||||||
}
|
}
|
||||||
|
if t == root {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.parent.subChildCount()
|
||||||
|
allTasks.Remove(t)
|
||||||
logger.Trace().Msg("task " + t.name + " finished")
|
logger.Trace().Msg("task " + t.name + " finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
// debug only.
|
|
||||||
func (t *Task) listChildren() []string {
|
|
||||||
var children []string
|
|
||||||
allTasks.Range(func(child *Task) bool {
|
|
||||||
if strings.HasPrefix(child.name, t.name+".") {
|
|
||||||
children = append(children, child.name)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
return children
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitWithTimeout(wg *sync.WaitGroup) bool {
|
|
||||||
done := make(chan struct{})
|
|
||||||
timeout := time.After(taskTimeout)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return true
|
|
||||||
case <-timeout:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fmtCause(cause any) error {
|
|
||||||
switch cause := cause.(type) {
|
|
||||||
case nil:
|
|
||||||
return nil
|
|
||||||
case error:
|
|
||||||
return cause
|
|
||||||
case string:
|
|
||||||
return errors.New(cause)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%v", cause)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||||
//
|
//
|
||||||
// This should not be called after Finish is called.
|
// This should not be called after Finish is called.
|
||||||
|
@ -170,19 +135,19 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||||
child := &Task{
|
child := &Task{
|
||||||
finished: make(chan struct{}, 1),
|
parent: t,
|
||||||
|
finished: make(chan struct{}),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
if t != root {
|
if t != root {
|
||||||
child.name = t.name + "." + name
|
child.name = t.name + "." + name
|
||||||
allTasks.Add(child)
|
|
||||||
} else {
|
} else {
|
||||||
child.name = name
|
child.name = name
|
||||||
}
|
}
|
||||||
|
|
||||||
allTasksWg.Add(1)
|
allTasks.Add(child)
|
||||||
t.children.Add(1)
|
t.addChildCount()
|
||||||
|
|
||||||
if !nf {
|
if !nf {
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -191,13 +156,6 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-child.finished
|
|
||||||
allTasksWg.Done()
|
|
||||||
t.children.Done()
|
|
||||||
allTasks.Remove(child)
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Trace().Msg("task " + child.name + " started")
|
logger.Trace().Msg("task " + child.name + " started")
|
||||||
return child
|
return child
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func testTask() *Task {
|
func testTask() *Task {
|
||||||
return RootTask("test", false)
|
return RootTask("test", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChildTaskCancellation(t *testing.T) {
|
func TestChildTaskCancellation(t *testing.T) {
|
||||||
|
@ -34,7 +34,7 @@ func TestChildTaskCancellation(t *testing.T) {
|
||||||
parent.cancel(nil) // should also cancel child
|
parent.cancel(nil) // should also cancel child
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-child.Finished():
|
case <-child.Context().Done():
|
||||||
ExpectError(t, context.Canceled, child.Context().Err())
|
ExpectError(t, context.Canceled, child.Context().Err())
|
||||||
default:
|
default:
|
||||||
t.Fatal("subTask context was not canceled as expected")
|
t.Fatal("subTask context was not canceled as expected")
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
@ -19,23 +18,22 @@ var logger = logging.With().Str("module", "task").Logger()
|
||||||
var (
|
var (
|
||||||
root = newRoot()
|
root = newRoot()
|
||||||
allTasks = F.NewSet[*Task]()
|
allTasks = F.NewSet[*Task]()
|
||||||
allTasksWg sync.WaitGroup
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func testCleanup() {
|
func testCleanup() {
|
||||||
root = newRoot()
|
root = newRoot()
|
||||||
allTasks.Clear()
|
allTasks.Clear()
|
||||||
allTasksWg = sync.WaitGroup{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RootTask returns a new Task with the given name, derived from the root context.
|
// RootTask returns a new Task with the given name, derived from the root context.
|
||||||
func RootTask(name string, needFinish bool) *Task {
|
func RootTask(name string, needFinish ...bool) *Task {
|
||||||
return root.Subtask(name, needFinish)
|
return root.Subtask(name, needFinish...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRoot() *Task {
|
func newRoot() *Task {
|
||||||
t := &Task{name: "root"}
|
t := &Task{name: "root"}
|
||||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||||
|
t.callbacks = make(map[*Callback]struct{})
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,19 +55,12 @@ func OnProgramExit(about string, fn func()) {
|
||||||
// still running when the timeout was reached, and their current tree
|
// still running when the timeout was reached, and their current tree
|
||||||
// of subtasks.
|
// of subtasks.
|
||||||
func GracefulShutdown(timeout time.Duration) (err error) {
|
func GracefulShutdown(timeout time.Duration) (err error) {
|
||||||
root.cancel(ErrProgramExiting)
|
go root.Finish(ErrProgramExiting)
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
after := time.After(timeout)
|
after := time.After(timeout)
|
||||||
|
|
||||||
go func() {
|
|
||||||
allTasksWg.Wait()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-root.finished:
|
||||||
return
|
return
|
||||||
case <-after:
|
case <-after:
|
||||||
b, err := json.Marshal(DebugTaskList())
|
b, err := json.Marshal(DebugTaskList())
|
||||||
|
|
Loading…
Add table
Reference in a new issue