GoDoxy/internal/task/utils.go
2025-01-01 14:25:44 +08:00

98 lines
2 KiB
Go

package task
import (
"context"
"encoding/json"
"errors"
"slices"
"sync"
"time"
"github.com/yusing/go-proxy/internal/logging"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
var (
	// ErrProgramExiting is the cause GracefulShutdown passes when it cancels
	// the root task.
	ErrProgramExiting = errors.New("program exiting")

	// logger is the package logger, tagged with module=task.
	logger = logging.With().Str("module", "task").Logger()

	// root is the package-wide root task; RootTask creates its subtasks and
	// RootContext exposes its context.
	root = newRoot()

	// allTasks is the set of tasks reported by DebugTaskList.
	// NOTE(review): tasks are presumably added/removed by Task itself — confirm.
	allTasks = F.NewSet[*Task]()

	// allTasksWg is what GracefulShutdown waits on before returning.
	allTasksWg sync.WaitGroup
)
func testCleanup() {
root = newRoot()
allTasks.Clear()
allTasksWg = sync.WaitGroup{}
}
// RootTask creates a task called name as a direct subtask of the package's
// root task and returns it.
//
// needFinish is forwarded unchanged to Task.Subtask.
func RootTask(name string, needFinish bool) *Task {
	sub := root.Subtask(name, needFinish)
	return sub
}
// newRoot builds a task named "root" whose context is a cancel-with-cause
// context derived directly from context.Background().
func newRoot() *Task {
	ctx, cancel := context.WithCancelCause(context.Background())
	return &Task{
		name:   "root",
		ctx:    ctx,
		cancel: cancel,
	}
}
// RootContext returns the context owned by the root task. GracefulShutdown
// cancels it with ErrProgramExiting as the cause.
func RootContext() context.Context {
	ctx := root.ctx
	return ctx
}
// RootContextCanceled returns the Done channel of the root task's context,
// which is closed once that context is canceled.
func RootContextCanceled() <-chan struct{} {
	done := root.ctx.Done()
	return done
}
// OnProgramExit registers fn on the root task through Task.OnFinished, using
// about as the description of the callback.
//
// NOTE(review): fn presumably runs once the root task finishes, i.e. during
// program shutdown — confirm against Task.OnFinished.
func OnProgramExit(about string, fn func()) {
	root.OnFinished(about, fn)
}
// GracefulShutdown cancels the root task with ErrProgramExiting as the cause
// and then waits for allTasksWg to drain, for at most timeout.
//
// It returns nil when the wait group drained in time. Otherwise it logs a
// warning carrying the sorted names of the tasks still registered (see
// DebugTaskList) and returns context.DeadlineExceeded.
func GracefulShutdown(timeout time.Duration) error {
	root.cancel(ErrProgramExiting)

	done := make(chan struct{})
	go func() {
		// On timeout this goroutine stays blocked until the remaining tasks
		// finish; callers are expected to exit the program afterwards.
		allTasksWg.Wait()
		close(done)
	}()

	// Use an explicit timer so it is released as soon as the tasks finish
	// instead of lingering until the timeout elapses.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
		return nil
	case <-timer.C:
	}

	// Take one snapshot and use it for both the payload and the count so the
	// log line is self-consistent even if tasks finish while we are logging.
	remaining := DebugTaskList()
	payload, marshalErr := json.Marshal(remaining)
	if marshalErr != nil {
		logger.Warn().Err(marshalErr).Msg("failed to marshal tasks")
		return context.DeadlineExceeded
	}
	logger.Warn().RawJSON("tasks", payload).Msgf("Timeout waiting for these %d tasks to finish", len(remaining))
	return context.DeadlineExceeded
}
// DebugTaskList returns the names of every task currently held in allTasks,
// sorted in ascending order.
//
// The result is a fresh slice (never nil), so callers may modify it freely.
func DebugTaskList() []string {
	names := make([]string, 0, allTasks.Size())
	collect := func(task *Task) {
		names = append(names, task.name)
	}
	allTasks.RangeAll(collect)
	slices.Sort(names)
	return names
}