diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 12c0125..92b80c5 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -177,7 +177,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { timer := time.NewTimer(time.Until(renewalTime)) defer timer.Stop() - task := parent.Subtask("cert-renew-scheduler") + task := parent.Subtask("cert-renew-scheduler", true) defer task.Finish(nil) for { diff --git a/internal/gperr/multiline.go b/internal/gperr/multiline.go index 103dd2b..35dfc97 100644 --- a/internal/gperr/multiline.go +++ b/internal/gperr/multiline.go @@ -46,3 +46,10 @@ func (m *MultilineError) AddLines(lines ...any) *MultilineError { } return m } + +func (m *MultilineError) AddLinesString(lines ...string) *MultilineError { + for _, line := range lines { + m.add(newError(line)) + } + return m +} diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go index 4cf1139..54682ba 100644 --- a/internal/idlewatcher/watcher.go +++ b/internal/idlewatcher/watcher.go @@ -160,12 +160,12 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { watcherMap[key] = w go func() { cause := w.watchUntilDestroy() - if cause.Is(causeContainerDestroy) || cause.Is(task.ErrProgramExiting) { + if errors.Is(cause, causeContainerDestroy) || errors.Is(cause, task.ErrProgramExiting) { watcherMapMu.Lock() defer watcherMapMu.Unlock() delete(watcherMap, key) w.l.Info().Msg("idlewatcher stopped") - } else if !cause.Is(causeReload) { + } else if !errors.Is(cause, causeReload) { gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l) } @@ -254,7 +254,7 @@ func (w *Watcher) expires() time.Time { // // it exits only if the context is canceled, the container is destroyed, // errors occurred on docker client, or route provider died (mainly caused by config reload). -func (w *Watcher) watchUntilDestroy() (returnCause gperr.Error) { +func (w *Watcher) watchUntilDestroy() (returnCause error) { eventCh, errCh := w.provider.Watch(w.Task().Context()) for { diff --git a/internal/maxmind/maxmind_test.go b/internal/maxmind/maxmind_test.go index 53416c4..6836c51 100644 --- a/internal/maxmind/maxmind_test.go +++ b/internal/maxmind/maxmind_test.go @@ -114,7 +114,7 @@ func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) { mockDataDir(t) mockMaxMindDBOpen(t) - task := task.RootTask("test") + task := task.RootTask("test", true) defer task.Finish(nil) err := cfg.LoadMaxMindDB(task) if err != nil { diff --git a/internal/metrics/period/poller.go b/internal/metrics/period/poller.go index b068c8f..3675258 100644 --- a/internal/metrics/period/poller.go +++ b/internal/metrics/period/poller.go @@ -134,7 +134,7 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) { } func (p *Poller[T, AggregateT]) Start() { - t := task.RootTask("poller." + p.name) + t := task.RootTask("poller."+p.name, true) l := log.With().Str("name", p.name).Logger() err := p.load() if err != nil { diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 56a13d9..8d4846c 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -45,7 +45,7 @@ var maxRetries = map[zerolog.Level]int{ func StartNotifDispatcher(parent task.Parent) *Dispatcher { dispatcher = &Dispatcher{ - task: parent.Subtask("notification"), + task: parent.Subtask("notification", true), providers: F.NewSet[Provider](), logCh: make(chan *LogMessage), retryCh: make(chan *RetryMessage, 100), @@ -111,7 +111,7 @@ func (disp *Dispatcher) start() { } func (disp *Dispatcher) dispatch(msg *LogMessage) { - task := disp.task.Subtask("dispatcher") + task := disp.task.Subtask("dispatcher", true) defer task.Finish("notif dispatched") disp.providers.RangeAllParallel(func(p Provider) { @@ -126,7 +126,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) { } func (disp *Dispatcher) retry(messages []*RetryMessage) error { - task := disp.task.Subtask("retry") + task := disp.task.Subtask("retry", true) defer task.Finish("notif retried") errs := gperr.NewBuilder("notification failure") diff --git a/internal/route/stream.go b/internal/route/stream.go index 9c76de0..6be8af8 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -44,7 +44,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { if existing, ok := routes.Stream.Get(r.Key()); ok { return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName()) } - r.task = parent.Subtask("stream." + r.Name()) + r.task = parent.Subtask("stream."+r.Name(), true) r.Stream = NewStream(r) switch { diff --git a/internal/task/debug.go b/internal/task/debug.go index bf8a2db..6343cb0 100644 --- a/internal/task/debug.go +++ b/internal/task/debug.go @@ -1,43 +1,46 @@ package task import ( - "slices" - "strings" + "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/gperr" ) // debug only. -func (t *Task) listChildren() []string { - var children []string - allTasks.Range(func(child *Task) bool { - if child.parent == t { - children = append(children, strings.TrimPrefix(child.name, t.name+".")) - } - return true - }) - return children -} - -// debug only. -func (t *Task) listCallbacks() []string { - var callbacks []string - t.mu.Lock() - defer t.mu.Unlock() +func (t *Task) listStuckedCallbacks() []string { + callbacks := make([]string, 0, len(t.callbacks)) for c := range t.callbacks { - callbacks = append(callbacks, c.about) + if !c.done.Load() { + callbacks = append(callbacks, c.about) + } } return callbacks } -// DebugTaskList returns list of all tasks. -// -// The returned string is suitable for printing to the console. -func DebugTaskList() []string { - l := make([]string, 0, allTasks.Size()) - - allTasks.RangeAll(func(t *Task) { - l = append(l, t.name) - }) - - slices.Sort(l) - return l +// debug only. +func (t *Task) listStuckedChildren() []string { + children := make([]string, 0, len(t.children)) + for c := range t.children { + if c.isFinished() { + continue + } + children = append(children, c.String()) + if len(c.children) > 0 { + children = append(children, c.listStuckedChildren()...) + } + } + return children +} + +func (t *Task) reportStucked() { + callbacks := t.listStuckedCallbacks() + children := t.listStuckedChildren() + fmtOutput := gperr.Multiline(). + Addf("stucked callbacks: %d, stucked children: %d", + len(callbacks), len(children), + ). + Addf("callbacks"). + AddLinesString(callbacks...). + Addf("children"). + AddLinesString(children...) + log.Warn().Msg(fmtOutput.Error()) } diff --git a/internal/task/impl.go b/internal/task/impl.go index d6730d8..3428f2c 100644 --- a/internal/task/impl.go +++ b/internal/task/impl.go @@ -1,70 +1,217 @@ package task import ( + "context" "errors" "fmt" + "sync" "time" + _ "unsafe" ) +var ( + taskPool = make(chan *Task, 100) + + root = newRoot() +) + +func testCleanup() { + root = newRoot() +} + +func newRoot() *Task { + return newTask("root", nil, true) +} + +//go:inline +func newTask(name string, parent *Task, needFinish bool) *Task { + var t *Task + select { + case t = <-taskPool: + default: + t = &Task{} + } + t.name = name + t.parent = parent + if needFinish { + t.canceled = make(chan struct{}) + } else { + // it will not be nil, because root task always has a canceled channel + t.canceled = parent.canceled + } + return t +} + +func putTask(t *Task) { + select { + case taskPool <- t: + default: + return + } +} + +//go:inline +func (t *Task) setCause(cause error) { + if cause == nil { + t.cause = context.Canceled + } else { + t.cause = cause + } +} + +//go:inline func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) { t.mu.Lock() + if t.cause != nil { + t.mu.Unlock() + if waitSubTasks { + waitEmpty(t.children, taskTimeout) + } + fn() + return + } + defer t.mu.Unlock() if t.callbacks == nil { - t.callbacks = make(map[*Callback]struct{}) + t.callbacks = make(callbacksSet) } - if t.callbacksDone == nil { - t.callbacksDone = make(chan struct{}) - } - t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{} + t.callbacks[&Callback{ + fn: fn, + about: about, + waitChildren: waitSubTasks, + }] = struct{}{} } -func (t *Task) addChildCount() { +//go:inline +func (t *Task) addChild(child *Task) { t.mu.Lock() - defer t.mu.Unlock() - t.children++ - if t.children == 1 { - t.childrenDone = make(chan struct{}) + if t.cause != nil { + t.mu.Unlock() + child.Finish(t.FinishCause()) + return } + + defer t.mu.Unlock() + + if t.children == nil { + t.children = make(childrenSet) + } + t.children[child] = struct{}{} } -func (t *Task) subChildCount() { +//go:inline +func (t *Task) removeChild(child *Task) { t.mu.Lock() defer t.mu.Unlock() - t.children-- - switch t.children { - case 0: - close(t.childrenDone) - case ^uint32(0): - panic("negative child count") + delete(t.children, child) +} + +func (t *Task) finishChildren() { + t.mu.Lock() + if len(t.children) == 0 { + t.mu.Unlock() + return } + + var wg sync.WaitGroup + for child := range t.children { + wg.Add(1) + go func() { + defer wg.Done() + child.Finish(t.cause) + }() + } + + clear(t.children) + t.mu.Unlock() + wg.Wait() } func (t *Task) runCallbacks() { + t.mu.Lock() if len(t.callbacks) == 0 { + t.mu.Unlock() return } + + var wg sync.WaitGroup + var needWait bool + + // runs callbacks that does not need wait first for c := range t.callbacks { - if c.waitChildren { - waitWithTimeout(t.childrenDone) + if !c.waitChildren { + wg.Add(1) + go func() { + defer wg.Done() + invokeWithRecover(c) + }() + } else { + needWait = true } - t.invokeWithRecover(c.fn, c.about) - delete(t.callbacks, c) } - close(t.callbacksDone) + + // runs callbacks that need to wait for children + if needWait { + waitEmpty(t.children, taskTimeout) + for c := range t.callbacks { + if c.waitChildren { + wg.Add(1) + go func() { + defer wg.Done() + invokeWithRecover(c) + }() + } + } + } + + clear(t.callbacks) + t.mu.Unlock() + wg.Wait() } -func waitWithTimeout(ch <-chan struct{}) bool { - if ch == nil { +func (t *Task) waitFinish(timeout time.Duration) bool { + // return directly if already finished + if t.isFinished() { return true } - select { - case <-ch: + + if len(t.children) == 0 && len(t.callbacks) == 0 { return true - case <-time.After(taskTimeout): + } + + ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout) + if !ok { return false } + t.finished.Store(true) + return true } +//go:inline +func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool { + if len(set) == 0 { + return true + } + + var sema uint32 + + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + if len(set) == 0 { + return true + } + select { + case <-timer.C: + return false + default: + runtime_Semacquire(&sema) + } + } +} + +//go:inline func fmtCause(cause any) error { switch cause := cause.(type) { case nil: @@ -77,3 +224,6 @@ func fmtCause(cause any) error { return fmt.Errorf("%v", cause) } } + +//go:linkname runtime_Semacquire sync.runtime_Semacquire +func runtime_Semacquire(s *uint32) diff --git a/internal/task/task.go b/internal/task/task.go index 9faa7a6..ebbb414 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -1,16 +1,17 @@ +// This file has the abstract logic of the task system. +// +// The implementation of the task system is in the impl.go file. package task import ( "context" - "runtime/debug" + "errors" "sync" "sync/atomic" "time" "github.com/rs/zerolog/log" - "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -29,6 +30,7 @@ type ( fn func() about string waitChildren bool + done atomic.Bool } // Task controls objects' lifetime. // @@ -38,46 +40,64 @@ type ( Task struct { name string - parent *Task - children uint32 - childrenDone chan struct{} + parent *Task + children childrenSet + callbacks callbacksSet - callbacks map[*Callback]struct{} - callbacksDone chan struct{} + cause error - finished chan struct{} - // finishedCalled == 1 Finish has been called - // but does not mean that the task is finished yet - // this is used to avoid calling Finish twice - finishedCalled uint32 + canceled chan struct{} - mu sync.Mutex - - ctx context.Context - cancel context.CancelCauseFunc + finished atomic.Bool + mu sync.Mutex } Parent interface { Context() context.Context - Subtask(name string, needFinish ...bool) *Task + Subtask(name string, needFinish bool) *Task Name() string Finish(reason any) OnCancel(name string, f func()) } ) +type ( + childrenSet = map[*Task]struct{} + callbacksSet = map[*Callback]struct{} +) + const taskTimeout = 3 * time.Second func (t *Task) Context() context.Context { - return t.ctx + return t } -func (t *Task) Finished() <-chan struct{} { - return t.finished +func (t *Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (t *Task) Done() <-chan struct{} { + return t.canceled +} + +func (t *Task) Err() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.cause == nil { + return context.Canceled + } + return t.cause +} + +func (t *Task) Value(_ any) any { + return nil } // FinishCause returns the reason / error that caused the task to be finished. func (t *Task) FinishCause() error { - return context.Cause(t.ctx) + if t.cause == nil || errors.Is(t.cause, context.Canceled) { + return nil + } + return t.cause } // OnFinished calls fn when the task is canceled and all subtasks are finished. @@ -97,105 +117,86 @@ func (t *Task) OnCancel(about string, fn func()) { // Finish cancel all subtasks and wait for them to finish, // then marks the task as finished, with the given reason (if any). func (t *Task) Finish(reason any) { - if atomic.LoadUint32(&t.finishedCalled) == 1 { - return - } - t.mu.Lock() - if t.finishedCalled == 1 { + if t.cause != nil { t.mu.Unlock() return } - - t.finishedCalled = 1 + cause := fmtCause(reason) + t.setCause(cause) + // t does not need finish, it shares the canceled channel with its parent + if t == root || t.canceled != t.parent.canceled { + close(t.canceled) + } t.mu.Unlock() - t.finish(reason) + t.finishAndWait() } -func (t *Task) finish(reason any) { - t.cancel(fmtCause(reason)) - if !waitWithTimeout(t.childrenDone) { - log.Debug(). - Str("task", t.name). - Strs("subtasks", t.listChildren()). - Msg("Timeout waiting for subtasks to finish") +func (t *Task) finishAndWait() { + defer putTask(t) + + t.finishChildren() + t.runCallbacks() + + if !t.waitFinish(taskTimeout) { + t.reportStucked() } - go t.runCallbacks() - if !waitWithTimeout(t.callbacksDone) { - log.Debug(). - Str("task", t.name). - Strs("callbacks", t.listCallbacks()). - Msg("Timeout waiting for callbacks to finish") + // clear anyway + clear(t.children) + clear(t.callbacks) + + if t != root { + t.parent.removeChild(t) } - close(t.finished) - if t == root { - return - } - t.parent.subChildCount() - allTasks.Remove(t) - log.Trace().Msg("task " + t.name + " finished") + logFinished(t) +} + +func (t *Task) isFinished() bool { + return t.finished.Load() } // Subtask returns a new subtask with the given name, derived from the parent's context. // -// This should not be called after Finish is called. -func (t *Task) Subtask(name string, needFinish ...bool) *Task { - nf := len(needFinish) == 0 || needFinish[0] +// This should not be called after Finish is called on the task or its parent task. +func (t *Task) Subtask(name string, needFinish bool) *Task { + panicIfFinished(t, "Subtask is called") - ctx, cancel := context.WithCancelCause(t.ctx) - child := &Task{ - parent: t, - finished: make(chan struct{}), - ctx: ctx, - cancel: cancel, - } - if t != root { - child.name = t.name + "." + name - } else { - child.name = name + child := newTask(name, t, needFinish) + + if needFinish { + t.addChild(child) } - allTasks.Add(child) - t.addChildCount() - - if !nf { - go func() { - <-child.ctx.Done() - child.Finish(nil) - }() - } - - log.Trace().Msg("task " + child.name + " started") + logStarted(child) return child } // Name returns the name of the task without parent names. func (t *Task) Name() string { - parts := strutils.SplitRune(t.name, '.') - return parts[len(parts)-1] + return t.name } // String returns the full name of the task. func (t *Task) String() string { + if t.parent != nil { + return t.parent.String() + "." + t.name + } return t.name } // MarshalText implements encoding.TextMarshaler. func (t *Task) MarshalText() ([]byte, error) { - return []byte(t.name), nil + return []byte(t.String()), nil } -func (t *Task) invokeWithRecover(fn func(), caller string) { +func invokeWithRecover(cb *Callback) { defer func() { + cb.done.Store(true) if err := recover(); err != nil { - log.Error(). - Interface("err", err). - Msg("panic in task " + t.name + "." + caller) - if common.IsDebug { - panic(string(debug.Stack())) - } + log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic") + panicWithDebugStack() } }() - fn() + cb.fn() } diff --git a/internal/task/task_debug.go b/internal/task/task_debug.go new file mode 100644 index 0000000..80a03dc --- /dev/null +++ b/internal/task/task_debug.go @@ -0,0 +1,27 @@ +//go:build debug + +package task + +import ( + "runtime/debug" + + "github.com/rs/zerolog/log" +) + +func panicWithDebugStack() { + panic(string(debug.Stack())) +} + +func panicIfFinished(t *Task, reason string) { + if t.isFinished() { + log.Panic().Msg("task " + t.String() + " is finished but " + reason) + } +} + +func logStarted(t *Task) { + log.Info().Msg("task " + t.String() + " started") +} + +func logFinished(t *Task) { + log.Info().Msg("task " + t.String() + " finished") +} diff --git a/internal/task/task_prod.go b/internal/task/task_prod.go new file mode 100644 index 0000000..4a07d2f --- /dev/null +++ b/internal/task/task_prod.go @@ -0,0 +1,19 @@ +//go:build !debug + +package task + +func panicWithDebugStack() { + // do nothing +} + +func panicIfFinished(t *Task, reason string) { + // do nothing +} + +func logStarted(t *Task) { + // do nothing +} + +func logFinished(t *Task) { + // do nothing +} diff --git a/internal/task/task_test.go b/internal/task/task_test.go index f9c35aa..889e5c6 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -17,7 +17,7 @@ func TestChildTaskCancellation(t *testing.T) { t.Cleanup(testCleanup) parent := testTask() - child := parent.Subtask("") + child := parent.Subtask("", true) go func() { defer child.Finish(nil) @@ -31,7 +31,7 @@ func TestChildTaskCancellation(t *testing.T) { } }() - parent.cancel(nil) // should also cancel child + parent.Finish(nil) // should also cancel child select { case <-child.Context().Done(): @@ -41,6 +41,31 @@ func TestChildTaskCancellation(t *testing.T) { } } +func TestTaskStuck(t *testing.T) { + t.Cleanup(testCleanup) + task := testTask() + task.OnCancel("second", func() { + time.Sleep(time.Second) + }) + done := make(chan struct{}) + go func() { + task.Finish(nil) + close(done) + }() + time.Sleep(time.Millisecond * 100) + select { + case <-done: + t.Fatal("task finished unexpectedly") + default: + } + time.Sleep(time.Second) + select { + case <-done: + default: + t.Fatal("task did not finish") + } +} + func TestTaskOnCancelOnFinished(t *testing.T) { t.Cleanup(testCleanup) task := testTask() @@ -83,11 +108,13 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) { } }() - ExpectNoError(t, GracefulShutdown(1*time.Second)) + ExpectNoError(t, gracefulShutdown(1*time.Second)) + time.Sleep(100 * time.Millisecond) ExpectTrue(t, finished) - <-root.finished - ExpectError(t, context.Canceled, task.Context().Err()) + ExpectTrue(t, root.waitFinish(1*time.Second)) + ExpectError(t, ErrProgramExiting, context.Cause(task.Context())) + ExpectError(t, ErrProgramExiting, task.Context().Err()) ExpectError(t, ErrProgramExiting, task.FinishCause()) } @@ -95,7 +122,7 @@ func TestTimeoutOnGracefulShutdown(t *testing.T) { t.Cleanup(testCleanup) _ = testTask() - ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond)) + ExpectError(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond)) } func TestFinishMultipleCalls(t *testing.T) { @@ -112,10 +139,26 @@ func TestFinishMultipleCalls(t *testing.T) { wg.Wait() } -func BenchmarkTasks(b *testing.B) { - for range b.N { +func BenchmarkTasksNoFinish(b *testing.B) { + for b.Loop() { + task := RootTask("", false) + task.Subtask("", false).Finish(nil) + task.Finish(nil) + } +} + +func BenchmarkTasksNeedFinish(b *testing.B) { + for b.Loop() { task := testTask() task.Subtask("", true).Finish(nil) task.Finish(nil) } } + +func BenchmarkContextWithCancel(b *testing.B) { + for b.Loop() { + task, taskCancel := context.WithCancel(b.Context()) + taskCancel() + <-task.Done() + } +} diff --git a/internal/task/utils.go b/internal/task/utils.go index e75b3be..d45fbee 100644 --- a/internal/task/utils.go +++ b/internal/task/utils.go @@ -2,7 +2,6 @@ package task import ( "context" - "encoding/json" "errors" "os" "os/signal" @@ -10,73 +9,34 @@ import ( "time" "github.com/rs/zerolog/log" - F "github.com/yusing/go-proxy/internal/utils/functional" ) var ErrProgramExiting = errors.New("program exiting") -var ( - root = newRoot() - allTasks = F.NewSet[*Task]() -) - -func testCleanup() { - root = newRoot() - allTasks.Clear() -} - // RootTask returns a new Task with the given name, derived from the root context. -func RootTask(name string, needFinish ...bool) *Task { - return root.Subtask(name, needFinish...) -} - -func newRoot() *Task { - t := &Task{ - name: "root", - childrenDone: make(chan struct{}), - finished: make(chan struct{}), - } - t.ctx, t.cancel = context.WithCancelCause(context.Background()) - return t +// +//go:inline +func RootTask(name string, needFinish bool) *Task { + return root.Subtask(name, needFinish) } func RootContext() context.Context { - return root.ctx + return root } func RootContextCanceled() <-chan struct{} { - return root.ctx.Done() + return root.Done() } func OnProgramExit(about string, fn func()) { root.OnFinished(about, fn) } -// GracefulShutdown waits for all tasks to finish, up to the given timeout. +// WaitExit waits for a signal to shutdown the program, and then waits for all tasks to finish, up to the given timeout. // // If the timeout is exceeded, it prints a list of all tasks that were // still running when the timeout was reached, and their current tree // of subtasks. -func GracefulShutdown(timeout time.Duration) (err error) { - go root.Finish(ErrProgramExiting) - - after := time.After(timeout) - for { - select { - case <-root.finished: - return - case <-after: - b, err := json.Marshal(DebugTaskList()) - if err != nil { - log.Warn().Err(err).Msg("failed to marshal tasks") - return context.DeadlineExceeded - } - log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size()) - return context.DeadlineExceeded - } - } -} - func WaitExit(shutdownTimeout int) { sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) @@ -88,5 +48,22 @@ func WaitExit(shutdownTimeout int) { // gracefully shutdown log.Info().Msg("shutting down") - _ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout)) + if err := gracefulShutdown(time.Second * time.Duration(shutdownTimeout)); err != nil { + root.reportStucked() + } +} + +// gracefulShutdown waits for all tasks to finish, up to the given timeout. +// +// If the timeout is exceeded, it prints a list of all tasks that were +// still running when the timeout was reached, and their current tree +// of subtasks. +func gracefulShutdown(timeout time.Duration) error { + root.Finish(ErrProgramExiting) + root.finishChildren() + root.runCallbacks() + if !root.waitFinish(timeout) { + return context.DeadlineExceeded + } + return nil } diff --git a/internal/utils/testing/log.go b/internal/utils/testing/log.go new file mode 100644 index 0000000..bf662db --- /dev/null +++ b/internal/utils/testing/log.go @@ -0,0 +1,9 @@ +package expect + +import "github.com/rs/zerolog" + +func init() { + if isTest { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } +} diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index 4bfee33..fc7cada 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -56,7 +56,7 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher { fwMap: make(map[string]*fileWatcher), eventCh: make(chan Event), errCh: make(chan gperr.Error), - task: parent.Subtask("dir_watcher(" + dirPath + ")"), + task: parent.Subtask("dir_watcher("+dirPath+")", true), } go helper.start() return helper