package task import ( "context" "sync" "time" "github.com/yusing/go-proxy/internal/gperr" ) type ( TaskStarter interface { // Start starts the object that implements TaskStarter, // and returns an error if it fails to start. // // callerSubtask.Finish must be called when start fails or the object is finished. Start(parent Parent) gperr.Error Task() *Task } TaskFinisher interface { Finish(reason any) } Callback struct { fn func() about string } // Task controls objects' lifetime. // // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. // // Use Task.Finish to stop all subtasks of the Task. Task struct { parent *Task name string ctx context.Context cancel context.CancelCauseFunc done chan struct{} finishCalled bool onCancel *withWg[*Callback] onFinish *withWg[*Callback] children *withWg[*Task] mu sync.Mutex } Parent interface { Context() context.Context // Subtask returns a new subtask with the given name, derived from the parent's context. // // This should not be called after Finish is called on the task or its parent task. Subtask(name string, needFinish bool) *Task Name() string Finish(reason any) OnCancel(name string, f func()) } ) const taskTimeout = 3 * time.Second func (t *Task) Context() context.Context { return t.ctx } func (t *Task) Name() string { return t.name } // String returns the full name of the task. func (t *Task) String() string { return t.fullName() } // MarshalText implements encoding.TextMarshaler. func (t *Task) MarshalText() ([]byte, error) { return []byte(t.fullName()), nil } // Finish marks the task as finished, with the given reason (if any). func (t *Task) Finish(reason any) { t.finish(reason, false) } // FinishCause returns the reason / error that caused the task to be finished. func (t *Task) FinishCause() error { return context.Cause(t.ctx) } // FinishAndWait cancel all subtasks and wait for them to finish, // then marks the task as finished, with the given reason (if any). func (t *Task) FinishAndWait(reason any) { t.finish(reason, true) } // OnFinished calls fn when the task is canceled and all subtasks are finished. // // It should not be called after Finish is called. func (t *Task) OnFinished(about string, fn func()) { if !t.needFinish() { t.OnCancel(about, fn) return } t.mu.Lock() if t.onFinish == nil { t.onFinish = newWithWg[*Callback]() t.mu.Unlock() go func() { <-t.ctx.Done() <-t.done for cb := range t.onFinish.Range { go func(cb *Callback) { invokeWithRecover(cb) t.onFinish.Delete(cb) }(cb) } }() } else { t.mu.Unlock() } t.onFinish.Add(&Callback{fn: fn, about: about}) } // OnCancel calls fn when the task is canceled. // // It should not be called after Finish is called. func (t *Task) OnCancel(about string, fn func()) { t.mu.Lock() if t.onCancel == nil { t.onCancel = newWithWg[*Callback]() t.mu.Unlock() go func() { <-t.ctx.Done() for cb := range t.onCancel.Range { go func(cb *Callback) { invokeWithRecover(cb) t.onCancel.Delete(cb) }(cb) } }() } else { t.mu.Unlock() } t.onCancel.Add(&Callback{fn: fn, about: about}) } // Subtask returns a new subtask with the given name, derived from the parent's context. // // This should not be called after Finish is called on the task or its parent task. func (t *Task) Subtask(name string, needFinish bool) *Task { t.mu.Lock() if t.children == nil { t.children = newWithWg[*Task]() t.mu.Unlock() } else { t.mu.Unlock() } child := &Task{ name: name, parent: t, } t.children.Add(child) child.ctx, child.cancel = context.WithCancelCause(t.ctx) if needFinish { child.done = make(chan struct{}) } else { child.done = closedCh go func() { <-child.ctx.Done() child.Finish(t.FinishCause()) }() } logStarted(child) return child } func (t *Task) finish(reason any, wait bool) { t.mu.Lock() if t.finishCalled { t.mu.Unlock() // wait but not report stucked (again) t.waitFinish(taskTimeout) return } t.finishCalled = true t.mu.Unlock() if t.needFinish() { close(t.done) } t.cancel(fmtCause(reason)) if wait && !t.waitFinish(taskTimeout) { t.reportStucked() } if t != root { t.parent.children.Delete(t) } logFinished(t) } func (t *Task) waitFinish(timeout time.Duration) bool { if t.children == nil && t.onCancel == nil && t.onFinish == nil { return true } done := make(chan struct{}) go func() { if t.children != nil { t.children.Wait() } if t.onCancel != nil { t.onCancel.Wait() } if t.onFinish != nil { t.onFinish.Wait() } <-t.done close(done) }() timeoutCh := time.After(timeout) select { case <-done: return true case <-timeoutCh: return false } } func (t *Task) fullName() string { if t.parent == root { return t.name } return t.parent.fullName() + "." + t.name } func (t *Task) needFinish() bool { return t.done != closedCh }