package task import ( "context" "encoding/json" "errors" "time" "github.com/yusing/go-proxy/internal/logging" F "github.com/yusing/go-proxy/internal/utils/functional" ) var ErrProgramExiting = errors.New("program exiting") var ( root = newRoot() allTasks = F.NewSet[*Task]() ) func testCleanup() { root = newRoot() allTasks.Clear() } // RootTask returns a new Task with the given name, derived from the root context. func RootTask(name string, needFinish ...bool) *Task { return root.Subtask(name, needFinish...) } func newRoot() *Task { t := &Task{ name: "root", childrenDone: make(chan struct{}), finished: make(chan struct{}), } t.ctx, t.cancel = context.WithCancelCause(context.Background()) return t } func RootContext() context.Context { return root.ctx } func RootContextCanceled() <-chan struct{} { return root.ctx.Done() } func OnProgramExit(about string, fn func()) { root.OnFinished(about, fn) } // GracefulShutdown waits for all tasks to finish, up to the given timeout. // // If the timeout is exceeded, it prints a list of all tasks that were // still running when the timeout was reached, and their current tree // of subtasks. func GracefulShutdown(timeout time.Duration) (err error) { go root.Finish(ErrProgramExiting) after := time.After(timeout) for { select { case <-root.finished: return case <-after: b, err := json.Marshal(DebugTaskList()) if err != nil { logging.Warn().Err(err).Msg("failed to marshal tasks") return context.DeadlineExceeded } logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size()) return context.DeadlineExceeded } } }