GoDoxy/internal/task/task.go

246 lines
4.9 KiB
Go

package task
import (
"context"
"sync"
"time"
"github.com/yusing/go-proxy/internal/gperr"
)
type (
	// TaskStarter is one of the two interfaces that objects using a Task
	// are expected to implement (the other being TaskFinisher).
	TaskStarter interface {
		// Start starts the object that implements TaskStarter,
		// and returns an error if it fails to start.
		//
		// callerSubtask.Finish must be called when start fails or the object is finished.
		Start(parent Parent) gperr.Error
		// Task returns a *Task; presumably the one that controls the
		// object's lifetime (confirm against implementers).
		Task() *Task
	}
	// TaskFinisher is implemented by objects that can be finished with an
	// optional reason. *Task itself has a Finish method with this signature.
	TaskFinisher interface {
		Finish(reason any)
	}
	// Callback is a function registered through Task.OnCancel or
	// Task.OnFinished, together with a description of it.
	Callback struct {
		// fn is the function to invoke.
		fn func()
		// about is the caller-supplied description of fn.
		about string
	}
	// Task controls objects' lifetime.
	//
	// Objects that use a Task should implement the TaskStarter and the TaskFinisher interfaces.
	//
	// Use Task.Finish to stop all subtasks of the Task.
	Task struct {
		// parent is the task this one was created from via Subtask.
		parent *Task
		// name is the task's own name; fullName joins it with the parents' names.
		name string
		// ctx is returned by Context. For a subtask it is derived from the
		// parent's ctx with context.WithCancelCause.
		ctx context.Context
		// cancel cancels ctx with a cause; finish calls it.
		cancel context.CancelCauseFunc
		// done is closed by finish when the task needs an explicit finish.
		// Subtasks created with needFinish=false share closedCh instead.
		done chan struct{}
		// finishCalled is set by the first call to finish; guarded by mu.
		finishCalled bool
		// onCancel holds callbacks registered via OnCancel; created lazily.
		onCancel *withWg[*Callback]
		// onFinish holds callbacks registered via OnFinished; created lazily.
		onFinish *withWg[*Callback]
		// children holds the subtasks created via Subtask; created lazily.
		// A child removes itself from its parent's children in finish.
		children *withWg[*Task]
		// mu guards finishCalled and the lazy creation of onCancel,
		// onFinish and children.
		mu sync.Mutex
	}
	// Parent is the view of a task handed to TaskStarter.Start.
	//
	// *Task provides every method listed here.
	Parent interface {
		// Context returns the parent's context.
		Context() context.Context
		// Subtask returns a new subtask with the given name, derived from the parent's context.
		//
		// This should not be called after Finish is called on the task or its parent task.
		Subtask(name string, needFinish bool) *Task
		// Name returns the parent's name.
		Name() string
		// Finish finishes the parent with the given reason (if any).
		Finish(reason any)
		// OnCancel registers f, described by name, to be called when the
		// parent is canceled.
		OnCancel(name string, f func())
	}
)
// taskTimeout is the timeout finish passes to waitFinish, i.e. the longest
// a finishing task waits for its subtasks and callbacks.
const taskTimeout = 3 * time.Second
// Context returns the task's context.
//
// finish cancels this context, with a cause, when the task is finished.
func (t *Task) Context() context.Context {
	return t.ctx
}
// Name returns the task's own name, without any parent prefix.
//
// Use String for the full, dot-separated name.
func (t *Task) Name() string {
	return t.name
}
// String returns the full name of the task, as built by fullName:
// the names of its ancestors and its own name joined by ".".
func (t *Task) String() string {
	return t.fullName()
}
// MarshalText implements encoding.TextMarshaler.
//
// The text is the task's full name; the returned error is always nil.
func (t *Task) MarshalText() ([]byte, error) {
	return []byte(t.fullName()), nil
}
// Finish marks the task as finished, with the given reason (if any).
//
// The first call does not wait for subtasks or callbacks; use FinishAndWait
// for that. Calling it again on a task that is already finished only waits,
// up to taskTimeout, and returns.
func (t *Task) Finish(reason any) {
	t.finish(reason, false)
}
// FinishCause returns the reason / error that caused the task to be finished.
//
// It is context.Cause of the task's context, so it is nil while that
// context has not been canceled.
func (t *Task) FinishCause() error {
	return context.Cause(t.ctx)
}
// FinishAndWait marks the task as finished, with the given reason (if any),
// which cancels its context, and then waits up to taskTimeout for subtasks
// and registered callbacks to complete.
//
// If the wait times out the task is reported via reportStucked.
func (t *Task) FinishAndWait(reason any) {
	t.finish(reason, true)
}
// OnFinished registers fn, described by about, to be called once the task's
// context has been canceled and its done channel has been closed.
//
// For a task that does not need an explicit finish this is the same as
// calling OnCancel.
//
// It should not be called after Finish is called.
func (t *Task) OnFinished(about string, fn func()) {
	if !t.needFinish() {
		t.OnCancel(about, fn)
		return
	}

	t.mu.Lock()
	firstRegistration := t.onFinish == nil
	if firstRegistration {
		t.onFinish = newWithWg[*Callback]()
	}
	t.mu.Unlock()

	if firstRegistration {
		// A single dispatcher goroutine per task: it waits for the task to
		// finish, then runs every registered callback in its own goroutine.
		go func() {
			<-t.ctx.Done()
			<-t.done
			for registered := range t.onFinish.Range {
				go func() {
					invokeWithRecover(registered)
					t.onFinish.Delete(registered)
				}()
			}
		}()
	}

	t.onFinish.Add(&Callback{fn: fn, about: about})
}
// OnCancel registers fn, described by about, to be called once the task's
// context has been canceled.
//
// It should not be called after Finish is called.
func (t *Task) OnCancel(about string, fn func()) {
	t.mu.Lock()
	firstRegistration := t.onCancel == nil
	if firstRegistration {
		t.onCancel = newWithWg[*Callback]()
	}
	t.mu.Unlock()

	if firstRegistration {
		// A single dispatcher goroutine per task: it waits for cancellation,
		// then runs every registered callback in its own goroutine.
		go func() {
			<-t.ctx.Done()
			for registered := range t.onCancel.Range {
				go func() {
					invokeWithRecover(registered)
					t.onCancel.Delete(registered)
				}()
			}
		}()
	}

	t.onCancel.Add(&Callback{fn: fn, about: about})
}
// Subtask creates a task named name as a child of t. The child's context is
// derived from t's context, so it is canceled together with t.
//
// With needFinish set, the child gets its own done channel, which is closed
// when the child is finished. Otherwise the child shares closedCh and is
// finished automatically, with t's finish cause as the reason, as soon as
// its context is canceled.
//
// This should not be called after Finish is called on the task or its parent task.
func (t *Task) Subtask(name string, needFinish bool) *Task {
	t.mu.Lock()
	if t.children == nil {
		t.children = newWithWg[*Task]()
	}
	t.mu.Unlock()

	sub := &Task{parent: t, name: name}
	t.children.Add(sub)
	sub.ctx, sub.cancel = context.WithCancelCause(t.ctx)

	if !needFinish {
		sub.done = closedCh
		go func() {
			<-sub.ctx.Done()
			sub.Finish(t.FinishCause())
		}()
	} else {
		sub.done = make(chan struct{})
	}

	logStarted(sub)
	return sub
}
// finish is the shared implementation of Finish and FinishAndWait.
//
// The first call marks the task as finished, closes the done channel (only
// for tasks that need an explicit finish) and cancels the context with a
// cause built from reason. If wait is set it then waits up to taskTimeout
// and reports the task via reportStucked on timeout. Finally the task is
// removed from its parent's children (skipped for root) and the finish is
// logged.
//
// Every later call only waits up to taskTimeout and returns.
func (t *Task) finish(reason any, wait bool) {
	t.mu.Lock()
	alreadyFinished := t.finishCalled
	t.finishCalled = true
	t.mu.Unlock()

	if alreadyFinished {
		// wait, but do not report the task as stuck (again)
		t.waitFinish(taskTimeout)
		return
	}

	if t.needFinish() {
		close(t.done)
	}
	t.cancel(fmtCause(reason))

	if wait {
		if finished := t.waitFinish(taskTimeout); !finished {
			t.reportStucked()
		}
	}

	if t != root {
		t.parent.children.Delete(t)
	}
	logFinished(t)
}
// waitFinish waits for the task's children, its OnCancel and OnFinished
// callbacks, and the closing of its done channel, giving up after timeout.
//
// It reports true when everything completed in time and false on timeout.
// A task that has no children and no registered callbacks reports true
// immediately.
func (t *Task) waitFinish(timeout time.Duration) bool {
	hasPending := t.children != nil || t.onCancel != nil || t.onFinish != nil
	if !hasPending {
		return true
	}

	finished := make(chan struct{})
	// On timeout this goroutine is left running until the waits return.
	go func() {
		defer close(finished)
		if t.children != nil {
			t.children.Wait()
		}
		if t.onCancel != nil {
			t.onCancel.Wait()
		}
		if t.onFinish != nil {
			t.onFinish.Wait()
		}
		<-t.done
	}()

	select {
	case <-finished:
		return true
	case <-time.After(timeout):
		return false
	}
}
// fullName returns the task's name prefixed with the names of its ancestors,
// joined by ".". The root task itself is not included in the prefix.
//
// A task without a parent yields just its own name. Without the nil check
// the recursion would reach a nil receiver and panic; finish never
// dereferences root's parent, so root is presumably such a task.
func (t *Task) fullName() string {
	if t.parent == nil || t.parent == root {
		return t.name
	}
	return t.parent.fullName() + "." + t.name
}
// needFinish reports whether the task has its own done channel, i.e. done is
// not the shared closedCh that Subtask assigns to subtasks created with
// needFinish=false.
func (t *Task) needFinish() bool {
	return t.done != closedCh
}