package task

import (
	"context"
	"errors"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/yusing/go-proxy/internal/logging"
	F "github.com/yusing/go-proxy/internal/utils/functional"
)

// ErrProgramExiting is the cause GracefulShutdown passes to the root task's
// Finish method when program shutdown begins.
var ErrProgramExiting = errors.New("program exiting")

var (
	// root is the top-level task; every task returned by RootTask is a
	// subtask of it.
	root = newRoot()
	// allTasks is the set of tasks reported as "still running" when
	// GracefulShutdown times out.
	// NOTE(review): presumably maintained by Task methods in sibling files.
	allTasks = F.NewSet[*Task]()
)

// testCleanup replaces the root task with a fresh one and empties allTasks,
// giving tests a clean slate. It reassigns root without synchronization, so
// it must not run while other goroutines are using the task package.
func testCleanup() {
	root = newRoot()
	allTasks.Clear()
}

// RootTask returns a new Task with the given name, created as a subtask of the
// root task (and thus derived from the root context). needFinish is forwarded
// unchanged to Subtask.
func RootTask(name string, needFinish bool) *Task {
	return root.Subtask(name, needFinish)
}

// newRoot constructs the root task: its context is a cancellable child of
// context.Background whose cancel function accepts a cause
// (context.WithCancelCause), and its childrenDone/finished signalling
// channels are allocated unbuffered.
func newRoot() *Task {
	t := &Task{
		name:         "root",
		childrenDone: make(chan struct{}),
		finished:     make(chan struct{}),
	}
	t.ctx, t.cancel = context.WithCancelCause(context.Background())
	return t
}

// RootContext returns the root task's context.
func RootContext() context.Context {
	return root.ctx
}

// RootContextCanceled returns a channel that is closed once the root task's
// context is canceled.
func RootContextCanceled() <-chan struct{} {
	return root.ctx.Done()
}

// OnProgramExit registers fn, labelled by about, with the root task's
// OnFinished hook.
// NOTE(review): presumably fn runs when the root task finishes during
// GracefulShutdown; confirm against Task.OnFinished.
func OnProgramExit(about string, fn func()) {
	root.OnFinished(about, fn)
}

// GracefulShutdown asks the root task to finish (with cause ErrProgramExiting)
// and waits up to timeout for it to do so.
//
// It returns nil if the root task signals completion in time. Otherwise it
// logs a warning with the number of tracked tasks, logs the name of each
// task still in allTasks, and returns context.DeadlineExceeded.
func GracefulShutdown(timeout time.Duration) error {
	// Finish may block until children are done, so run it concurrently and
	// race it against the timeout below.
	go root.Finish(ErrProgramExiting)

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-root.finished:
		return nil
	case <-timer.C:
		logging.Warn().Msgf("Timeout waiting for %d tasks to finish", allTasks.Size())
		for t := range allTasks.Range {
			logging.Warn().Msgf("Task %s is still running", t.name)
		}
		return context.DeadlineExceeded
	}
}

// WaitExit blocks until the process receives SIGINT, SIGTERM or SIGHUP, then
// runs GracefulShutdown with a timeout of shutdownTimeout seconds.
func WaitExit(shutdownTimeout int) {
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

	// Block until the first shutdown signal arrives.
	<-sig
	// Stop relaying signals to sig so a repeated signal during shutdown is
	// no longer swallowed by this unread channel. Unless other code has also
	// registered for it, the runtime's default action applies again (e.g. a
	// second Ctrl-C terminates the process).
	signal.Stop(sig)

	logging.Info().Msg("shutting down")
	// GracefulShutdown already logs the tasks still running on timeout and
	// there is no further recovery possible here, so the error is dropped.
	_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
}