mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-30 16:42:35 +02:00
refactor: improve task management with xsync for concurrent access and enhance callback and subtasks handling as well as memory allocation
This commit is contained in:
commit
b163771956
16 changed files with 447 additions and 211 deletions
|
@ -177,7 +177,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
|||
timer := time.NewTimer(time.Until(renewalTime))
|
||||
defer timer.Stop()
|
||||
|
||||
task := parent.Subtask("cert-renew-scheduler")
|
||||
task := parent.Subtask("cert-renew-scheduler", true)
|
||||
defer task.Finish(nil)
|
||||
|
||||
for {
|
||||
|
|
|
@ -46,3 +46,10 @@ func (m *MultilineError) AddLines(lines ...any) *MultilineError {
|
|||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MultilineError) AddLinesString(lines ...string) *MultilineError {
|
||||
for _, line := range lines {
|
||||
m.add(newError(line))
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
|
|
@ -160,12 +160,12 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
|||
watcherMap[key] = w
|
||||
go func() {
|
||||
cause := w.watchUntilDestroy()
|
||||
if cause.Is(causeContainerDestroy) || cause.Is(task.ErrProgramExiting) {
|
||||
if errors.Is(cause, causeContainerDestroy) || errors.Is(cause, task.ErrProgramExiting) {
|
||||
watcherMapMu.Lock()
|
||||
defer watcherMapMu.Unlock()
|
||||
delete(watcherMap, key)
|
||||
w.l.Info().Msg("idlewatcher stopped")
|
||||
} else if !cause.Is(causeReload) {
|
||||
} else if !errors.Is(cause, causeReload) {
|
||||
gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l)
|
||||
}
|
||||
|
||||
|
@ -254,7 +254,7 @@ func (w *Watcher) expires() time.Time {
|
|||
//
|
||||
// it exits only if the context is canceled, the container is destroyed,
|
||||
// errors occurred on docker client, or route provider died (mainly caused by config reload).
|
||||
func (w *Watcher) watchUntilDestroy() (returnCause gperr.Error) {
|
||||
func (w *Watcher) watchUntilDestroy() (returnCause error) {
|
||||
eventCh, errCh := w.provider.Watch(w.Task().Context())
|
||||
|
||||
for {
|
||||
|
|
|
@ -114,7 +114,7 @@ func Test_MaxMindConfig_loadMaxMindDB(t *testing.T) {
|
|||
mockDataDir(t)
|
||||
mockMaxMindDBOpen(t)
|
||||
|
||||
task := task.RootTask("test")
|
||||
task := task.RootTask("test", true)
|
||||
defer task.Finish(nil)
|
||||
err := cfg.LoadMaxMindDB(task)
|
||||
if err != nil {
|
||||
|
|
|
@ -134,7 +134,7 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
|
|||
}
|
||||
|
||||
func (p *Poller[T, AggregateT]) Start() {
|
||||
t := task.RootTask("poller." + p.name)
|
||||
t := task.RootTask("poller."+p.name, true)
|
||||
l := log.With().Str("name", p.name).Logger()
|
||||
err := p.load()
|
||||
if err != nil {
|
||||
|
|
|
@ -45,7 +45,7 @@ var maxRetries = map[zerolog.Level]int{
|
|||
|
||||
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
|
||||
dispatcher = &Dispatcher{
|
||||
task: parent.Subtask("notification"),
|
||||
task: parent.Subtask("notification", true),
|
||||
providers: F.NewSet[Provider](),
|
||||
logCh: make(chan *LogMessage),
|
||||
retryCh: make(chan *RetryMessage, 100),
|
||||
|
@ -111,7 +111,7 @@ func (disp *Dispatcher) start() {
|
|||
}
|
||||
|
||||
func (disp *Dispatcher) dispatch(msg *LogMessage) {
|
||||
task := disp.task.Subtask("dispatcher")
|
||||
task := disp.task.Subtask("dispatcher", true)
|
||||
defer task.Finish("notif dispatched")
|
||||
|
||||
disp.providers.RangeAllParallel(func(p Provider) {
|
||||
|
@ -126,7 +126,7 @@ func (disp *Dispatcher) dispatch(msg *LogMessage) {
|
|||
}
|
||||
|
||||
func (disp *Dispatcher) retry(messages []*RetryMessage) error {
|
||||
task := disp.task.Subtask("retry")
|
||||
task := disp.task.Subtask("retry", true)
|
||||
defer task.Finish("notif retried")
|
||||
|
||||
errs := gperr.NewBuilder("notification failure")
|
||||
|
|
|
@ -44,7 +44,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
|||
if existing, ok := routes.Stream.Get(r.Key()); ok {
|
||||
return gperr.Errorf("route already exists: from provider %s and %s", existing.ProviderName(), r.ProviderName())
|
||||
}
|
||||
r.task = parent.Subtask("stream." + r.Name())
|
||||
r.task = parent.Subtask("stream."+r.Name(), true)
|
||||
r.Stream = NewStream(r)
|
||||
|
||||
switch {
|
||||
|
|
|
@ -1,43 +1,46 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
)
|
||||
|
||||
// debug only.
|
||||
func (t *Task) listChildren() []string {
|
||||
var children []string
|
||||
allTasks.Range(func(child *Task) bool {
|
||||
if child.parent == t {
|
||||
children = append(children, strings.TrimPrefix(child.name, t.name+"."))
|
||||
}
|
||||
return true
|
||||
})
|
||||
return children
|
||||
}
|
||||
|
||||
// debug only.
|
||||
func (t *Task) listCallbacks() []string {
|
||||
var callbacks []string
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
func (t *Task) listStuckedCallbacks() []string {
|
||||
callbacks := make([]string, 0, len(t.callbacks))
|
||||
for c := range t.callbacks {
|
||||
callbacks = append(callbacks, c.about)
|
||||
if !c.done.Load() {
|
||||
callbacks = append(callbacks, c.about)
|
||||
}
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
|
||||
// DebugTaskList returns list of all tasks.
|
||||
//
|
||||
// The returned string is suitable for printing to the console.
|
||||
func DebugTaskList() []string {
|
||||
l := make([]string, 0, allTasks.Size())
|
||||
|
||||
allTasks.RangeAll(func(t *Task) {
|
||||
l = append(l, t.name)
|
||||
})
|
||||
|
||||
slices.Sort(l)
|
||||
return l
|
||||
// debug only.
|
||||
func (t *Task) listStuckedChildren() []string {
|
||||
children := make([]string, 0, len(t.children))
|
||||
for c := range t.children {
|
||||
if c.isFinished() {
|
||||
continue
|
||||
}
|
||||
children = append(children, c.String())
|
||||
if len(c.children) > 0 {
|
||||
children = append(children, c.listStuckedChildren()...)
|
||||
}
|
||||
}
|
||||
return children
|
||||
}
|
||||
|
||||
func (t *Task) reportStucked() {
|
||||
callbacks := t.listStuckedCallbacks()
|
||||
children := t.listStuckedChildren()
|
||||
fmtOutput := gperr.Multiline().
|
||||
Addf("stucked callbacks: %d, stucked children: %d",
|
||||
len(callbacks), len(children),
|
||||
).
|
||||
Addf("callbacks").
|
||||
AddLinesString(callbacks...).
|
||||
Addf("children").
|
||||
AddLinesString(children...)
|
||||
log.Warn().Msg(fmtOutput.Error())
|
||||
}
|
||||
|
|
|
@ -1,70 +1,217 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
taskPool = make(chan *Task, 100)
|
||||
|
||||
root = newRoot()
|
||||
)
|
||||
|
||||
func testCleanup() {
|
||||
root = newRoot()
|
||||
}
|
||||
|
||||
func newRoot() *Task {
|
||||
return newTask("root", nil, true)
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func newTask(name string, parent *Task, needFinish bool) *Task {
|
||||
var t *Task
|
||||
select {
|
||||
case t = <-taskPool:
|
||||
default:
|
||||
t = &Task{}
|
||||
}
|
||||
t.name = name
|
||||
t.parent = parent
|
||||
if needFinish {
|
||||
t.canceled = make(chan struct{})
|
||||
} else {
|
||||
// it will not be nil, because root task always has a canceled channel
|
||||
t.canceled = parent.canceled
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func putTask(t *Task) {
|
||||
select {
|
||||
case taskPool <- t:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func (t *Task) setCause(cause error) {
|
||||
if cause == nil {
|
||||
t.cause = context.Canceled
|
||||
} else {
|
||||
t.cause = cause
|
||||
}
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
|
||||
t.mu.Lock()
|
||||
if t.cause != nil {
|
||||
t.mu.Unlock()
|
||||
if waitSubTasks {
|
||||
waitEmpty(t.children, taskTimeout)
|
||||
}
|
||||
fn()
|
||||
return
|
||||
}
|
||||
|
||||
defer t.mu.Unlock()
|
||||
if t.callbacks == nil {
|
||||
t.callbacks = make(map[*Callback]struct{})
|
||||
t.callbacks = make(callbacksSet)
|
||||
}
|
||||
if t.callbacksDone == nil {
|
||||
t.callbacksDone = make(chan struct{})
|
||||
}
|
||||
t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{}
|
||||
t.callbacks[&Callback{
|
||||
fn: fn,
|
||||
about: about,
|
||||
waitChildren: waitSubTasks,
|
||||
}] = struct{}{}
|
||||
}
|
||||
|
||||
func (t *Task) addChildCount() {
|
||||
//go:inline
|
||||
func (t *Task) addChild(child *Task) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.children++
|
||||
if t.children == 1 {
|
||||
t.childrenDone = make(chan struct{})
|
||||
if t.cause != nil {
|
||||
t.mu.Unlock()
|
||||
child.Finish(t.FinishCause())
|
||||
return
|
||||
}
|
||||
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.children == nil {
|
||||
t.children = make(childrenSet)
|
||||
}
|
||||
t.children[child] = struct{}{}
|
||||
}
|
||||
|
||||
func (t *Task) subChildCount() {
|
||||
//go:inline
|
||||
func (t *Task) removeChild(child *Task) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.children--
|
||||
switch t.children {
|
||||
case 0:
|
||||
close(t.childrenDone)
|
||||
case ^uint32(0):
|
||||
panic("negative child count")
|
||||
delete(t.children, child)
|
||||
}
|
||||
|
||||
func (t *Task) finishChildren() {
|
||||
t.mu.Lock()
|
||||
if len(t.children) == 0 {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for child := range t.children {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
child.Finish(t.cause)
|
||||
}()
|
||||
}
|
||||
|
||||
clear(t.children)
|
||||
t.mu.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (t *Task) runCallbacks() {
|
||||
t.mu.Lock()
|
||||
if len(t.callbacks) == 0 {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var needWait bool
|
||||
|
||||
// runs callbacks that does not need wait first
|
||||
for c := range t.callbacks {
|
||||
if c.waitChildren {
|
||||
waitWithTimeout(t.childrenDone)
|
||||
if !c.waitChildren {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
invokeWithRecover(c)
|
||||
}()
|
||||
} else {
|
||||
needWait = true
|
||||
}
|
||||
t.invokeWithRecover(c.fn, c.about)
|
||||
delete(t.callbacks, c)
|
||||
}
|
||||
close(t.callbacksDone)
|
||||
|
||||
// runs callbacks that need to wait for children
|
||||
if needWait {
|
||||
waitEmpty(t.children, taskTimeout)
|
||||
for c := range t.callbacks {
|
||||
if c.waitChildren {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
invokeWithRecover(c)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clear(t.callbacks)
|
||||
t.mu.Unlock()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func waitWithTimeout(ch <-chan struct{}) bool {
|
||||
if ch == nil {
|
||||
func (t *Task) waitFinish(timeout time.Duration) bool {
|
||||
// return directly if already finished
|
||||
if t.isFinished() {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ch:
|
||||
|
||||
if len(t.children) == 0 && len(t.callbacks) == 0 {
|
||||
return true
|
||||
case <-time.After(taskTimeout):
|
||||
}
|
||||
|
||||
ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
t.finished.Store(true)
|
||||
return true
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
|
||||
if len(set) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
var sema uint32
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
if len(set) == 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-timer.C:
|
||||
return false
|
||||
default:
|
||||
runtime_Semacquire(&sema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//go:inline
|
||||
func fmtCause(cause any) error {
|
||||
switch cause := cause.(type) {
|
||||
case nil:
|
||||
|
@ -77,3 +224,6 @@ func fmtCause(cause any) error {
|
|||
return fmt.Errorf("%v", cause)
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname runtime_Semacquire sync.runtime_Semacquire
|
||||
func runtime_Semacquire(s *uint32)
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
// This file has the abstract logic of the task system.
|
||||
//
|
||||
// The implementation of the task system is in the impl.go file.
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -29,6 +30,7 @@ type (
|
|||
fn func()
|
||||
about string
|
||||
waitChildren bool
|
||||
done atomic.Bool
|
||||
}
|
||||
// Task controls objects' lifetime.
|
||||
//
|
||||
|
@ -38,46 +40,64 @@ type (
|
|||
Task struct {
|
||||
name string
|
||||
|
||||
parent *Task
|
||||
children uint32
|
||||
childrenDone chan struct{}
|
||||
parent *Task
|
||||
children childrenSet
|
||||
callbacks callbacksSet
|
||||
|
||||
callbacks map[*Callback]struct{}
|
||||
callbacksDone chan struct{}
|
||||
cause error
|
||||
|
||||
finished chan struct{}
|
||||
// finishedCalled == 1 Finish has been called
|
||||
// but does not mean that the task is finished yet
|
||||
// this is used to avoid calling Finish twice
|
||||
finishedCalled uint32
|
||||
canceled chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
finished atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
Parent interface {
|
||||
Context() context.Context
|
||||
Subtask(name string, needFinish ...bool) *Task
|
||||
Subtask(name string, needFinish bool) *Task
|
||||
Name() string
|
||||
Finish(reason any)
|
||||
OnCancel(name string, f func())
|
||||
}
|
||||
)
|
||||
|
||||
type (
|
||||
childrenSet = map[*Task]struct{}
|
||||
callbacksSet = map[*Callback]struct{}
|
||||
)
|
||||
|
||||
const taskTimeout = 3 * time.Second
|
||||
|
||||
func (t *Task) Context() context.Context {
|
||||
return t.ctx
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Task) Finished() <-chan struct{} {
|
||||
return t.finished
|
||||
func (t *Task) Deadline() (time.Time, bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (t *Task) Done() <-chan struct{} {
|
||||
return t.canceled
|
||||
}
|
||||
|
||||
func (t *Task) Err() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.cause == nil {
|
||||
return context.Canceled
|
||||
}
|
||||
return t.cause
|
||||
}
|
||||
|
||||
func (t *Task) Value(_ any) any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishCause returns the reason / error that caused the task to be finished.
|
||||
func (t *Task) FinishCause() error {
|
||||
return context.Cause(t.ctx)
|
||||
if t.cause == nil || errors.Is(t.cause, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return t.cause
|
||||
}
|
||||
|
||||
// OnFinished calls fn when the task is canceled and all subtasks are finished.
|
||||
|
@ -97,105 +117,86 @@ func (t *Task) OnCancel(about string, fn func()) {
|
|||
// Finish cancel all subtasks and wait for them to finish,
|
||||
// then marks the task as finished, with the given reason (if any).
|
||||
func (t *Task) Finish(reason any) {
|
||||
if atomic.LoadUint32(&t.finishedCalled) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if t.finishedCalled == 1 {
|
||||
if t.cause != nil {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
t.finishedCalled = 1
|
||||
cause := fmtCause(reason)
|
||||
t.setCause(cause)
|
||||
// t does not need finish, it shares the canceled channel with its parent
|
||||
if t == root || t.canceled != t.parent.canceled {
|
||||
close(t.canceled)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
t.finish(reason)
|
||||
t.finishAndWait()
|
||||
}
|
||||
|
||||
func (t *Task) finish(reason any) {
|
||||
t.cancel(fmtCause(reason))
|
||||
if !waitWithTimeout(t.childrenDone) {
|
||||
log.Debug().
|
||||
Str("task", t.name).
|
||||
Strs("subtasks", t.listChildren()).
|
||||
Msg("Timeout waiting for subtasks to finish")
|
||||
func (t *Task) finishAndWait() {
|
||||
defer putTask(t)
|
||||
|
||||
t.finishChildren()
|
||||
t.runCallbacks()
|
||||
|
||||
if !t.waitFinish(taskTimeout) {
|
||||
t.reportStucked()
|
||||
}
|
||||
go t.runCallbacks()
|
||||
if !waitWithTimeout(t.callbacksDone) {
|
||||
log.Debug().
|
||||
Str("task", t.name).
|
||||
Strs("callbacks", t.listCallbacks()).
|
||||
Msg("Timeout waiting for callbacks to finish")
|
||||
// clear anyway
|
||||
clear(t.children)
|
||||
clear(t.callbacks)
|
||||
|
||||
if t != root {
|
||||
t.parent.removeChild(t)
|
||||
}
|
||||
close(t.finished)
|
||||
if t == root {
|
||||
return
|
||||
}
|
||||
t.parent.subChildCount()
|
||||
allTasks.Remove(t)
|
||||
log.Trace().Msg("task " + t.name + " finished")
|
||||
logFinished(t)
|
||||
}
|
||||
|
||||
func (t *Task) isFinished() bool {
|
||||
return t.finished.Load()
|
||||
}
|
||||
|
||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||
//
|
||||
// This should not be called after Finish is called.
|
||||
func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||
nf := len(needFinish) == 0 || needFinish[0]
|
||||
// This should not be called after Finish is called on the task or its parent task.
|
||||
func (t *Task) Subtask(name string, needFinish bool) *Task {
|
||||
panicIfFinished(t, "Subtask is called")
|
||||
|
||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||
child := &Task{
|
||||
parent: t,
|
||||
finished: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
if t != root {
|
||||
child.name = t.name + "." + name
|
||||
} else {
|
||||
child.name = name
|
||||
child := newTask(name, t, needFinish)
|
||||
|
||||
if needFinish {
|
||||
t.addChild(child)
|
||||
}
|
||||
|
||||
allTasks.Add(child)
|
||||
t.addChildCount()
|
||||
|
||||
if !nf {
|
||||
go func() {
|
||||
<-child.ctx.Done()
|
||||
child.Finish(nil)
|
||||
}()
|
||||
}
|
||||
|
||||
log.Trace().Msg("task " + child.name + " started")
|
||||
logStarted(child)
|
||||
return child
|
||||
}
|
||||
|
||||
// Name returns the name of the task without parent names.
|
||||
func (t *Task) Name() string {
|
||||
parts := strutils.SplitRune(t.name, '.')
|
||||
return parts[len(parts)-1]
|
||||
return t.name
|
||||
}
|
||||
|
||||
// String returns the full name of the task.
|
||||
func (t *Task) String() string {
|
||||
if t.parent != nil {
|
||||
return t.parent.String() + "." + t.name
|
||||
}
|
||||
return t.name
|
||||
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (t *Task) MarshalText() ([]byte, error) {
|
||||
return []byte(t.name), nil
|
||||
return []byte(t.String()), nil
|
||||
}
|
||||
|
||||
func (t *Task) invokeWithRecover(fn func(), caller string) {
|
||||
func invokeWithRecover(cb *Callback) {
|
||||
defer func() {
|
||||
cb.done.Store(true)
|
||||
if err := recover(); err != nil {
|
||||
log.Error().
|
||||
Interface("err", err).
|
||||
Msg("panic in task " + t.name + "." + caller)
|
||||
if common.IsDebug {
|
||||
panic(string(debug.Stack()))
|
||||
}
|
||||
log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic")
|
||||
panicWithDebugStack()
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
cb.fn()
|
||||
}
|
||||
|
|
27
internal/task/task_debug.go
Normal file
27
internal/task/task_debug.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
//go:build debug
|
||||
|
||||
package task
|
||||
|
||||
import (
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func panicWithDebugStack() {
|
||||
panic(string(debug.Stack()))
|
||||
}
|
||||
|
||||
func panicIfFinished(t *Task, reason string) {
|
||||
if t.isFinished() {
|
||||
log.Panic().Msg("task " + t.String() + " is finished but " + reason)
|
||||
}
|
||||
}
|
||||
|
||||
func logStarted(t *Task) {
|
||||
log.Info().Msg("task " + t.String() + " started")
|
||||
}
|
||||
|
||||
func logFinished(t *Task) {
|
||||
log.Info().Msg("task " + t.String() + " finished")
|
||||
}
|
19
internal/task/task_prod.go
Normal file
19
internal/task/task_prod.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
//go:build !debug
|
||||
|
||||
package task
|
||||
|
||||
func panicWithDebugStack() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func panicIfFinished(t *Task, reason string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func logStarted(t *Task) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func logFinished(t *Task) {
|
||||
// do nothing
|
||||
}
|
|
@ -17,7 +17,7 @@ func TestChildTaskCancellation(t *testing.T) {
|
|||
t.Cleanup(testCleanup)
|
||||
|
||||
parent := testTask()
|
||||
child := parent.Subtask("")
|
||||
child := parent.Subtask("", true)
|
||||
|
||||
go func() {
|
||||
defer child.Finish(nil)
|
||||
|
@ -31,7 +31,7 @@ func TestChildTaskCancellation(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
parent.cancel(nil) // should also cancel child
|
||||
parent.Finish(nil) // should also cancel child
|
||||
|
||||
select {
|
||||
case <-child.Context().Done():
|
||||
|
@ -41,6 +41,31 @@ func TestChildTaskCancellation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTaskStuck(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
task := testTask()
|
||||
task.OnCancel("second", func() {
|
||||
time.Sleep(time.Second)
|
||||
})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
task.Finish(nil)
|
||||
close(done)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("task finished unexpectedly")
|
||||
default:
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
t.Fatal("task did not finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskOnCancelOnFinished(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
task := testTask()
|
||||
|
@ -83,11 +108,13 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
ExpectNoError(t, GracefulShutdown(1*time.Second))
|
||||
ExpectNoError(t, gracefulShutdown(1*time.Second))
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
ExpectTrue(t, finished)
|
||||
|
||||
<-root.finished
|
||||
ExpectError(t, context.Canceled, task.Context().Err())
|
||||
ExpectTrue(t, root.waitFinish(1*time.Second))
|
||||
ExpectError(t, ErrProgramExiting, context.Cause(task.Context()))
|
||||
ExpectError(t, ErrProgramExiting, task.Context().Err())
|
||||
ExpectError(t, ErrProgramExiting, task.FinishCause())
|
||||
}
|
||||
|
||||
|
@ -95,7 +122,7 @@ func TestTimeoutOnGracefulShutdown(t *testing.T) {
|
|||
t.Cleanup(testCleanup)
|
||||
_ = testTask()
|
||||
|
||||
ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
|
||||
ExpectError(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond))
|
||||
}
|
||||
|
||||
func TestFinishMultipleCalls(t *testing.T) {
|
||||
|
@ -112,10 +139,26 @@ func TestFinishMultipleCalls(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkTasks(b *testing.B) {
|
||||
for range b.N {
|
||||
func BenchmarkTasksNoFinish(b *testing.B) {
|
||||
for b.Loop() {
|
||||
task := RootTask("", false)
|
||||
task.Subtask("", false).Finish(nil)
|
||||
task.Finish(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTasksNeedFinish(b *testing.B) {
|
||||
for b.Loop() {
|
||||
task := testTask()
|
||||
task.Subtask("", true).Finish(nil)
|
||||
task.Finish(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContextWithCancel(b *testing.B) {
|
||||
for b.Loop() {
|
||||
task, taskCancel := context.WithCancel(b.Context())
|
||||
taskCancel()
|
||||
<-task.Done()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package task
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -10,73 +9,34 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
var ErrProgramExiting = errors.New("program exiting")
|
||||
|
||||
var (
|
||||
root = newRoot()
|
||||
allTasks = F.NewSet[*Task]()
|
||||
)
|
||||
|
||||
func testCleanup() {
|
||||
root = newRoot()
|
||||
allTasks.Clear()
|
||||
}
|
||||
|
||||
// RootTask returns a new Task with the given name, derived from the root context.
|
||||
func RootTask(name string, needFinish ...bool) *Task {
|
||||
return root.Subtask(name, needFinish...)
|
||||
}
|
||||
|
||||
func newRoot() *Task {
|
||||
t := &Task{
|
||||
name: "root",
|
||||
childrenDone: make(chan struct{}),
|
||||
finished: make(chan struct{}),
|
||||
}
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
return t
|
||||
//
|
||||
//go:inline
|
||||
func RootTask(name string, needFinish bool) *Task {
|
||||
return root.Subtask(name, needFinish)
|
||||
}
|
||||
|
||||
func RootContext() context.Context {
|
||||
return root.ctx
|
||||
return root
|
||||
}
|
||||
|
||||
func RootContextCanceled() <-chan struct{} {
|
||||
return root.ctx.Done()
|
||||
return root.Done()
|
||||
}
|
||||
|
||||
func OnProgramExit(about string, fn func()) {
|
||||
root.OnFinished(about, fn)
|
||||
}
|
||||
|
||||
// GracefulShutdown waits for all tasks to finish, up to the given timeout.
|
||||
// WaitExit waits for a signal to shutdown the program, and then waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func GracefulShutdown(timeout time.Duration) (err error) {
|
||||
go root.Finish(ErrProgramExiting)
|
||||
|
||||
after := time.After(timeout)
|
||||
for {
|
||||
select {
|
||||
case <-root.finished:
|
||||
return
|
||||
case <-after:
|
||||
b, err := json.Marshal(DebugTaskList())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed to marshal tasks")
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WaitExit(shutdownTimeout int) {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT)
|
||||
|
@ -88,5 +48,22 @@ func WaitExit(shutdownTimeout int) {
|
|||
|
||||
// gracefully shutdown
|
||||
log.Info().Msg("shutting down")
|
||||
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
|
||||
if err := gracefulShutdown(time.Second * time.Duration(shutdownTimeout)); err != nil {
|
||||
root.reportStucked()
|
||||
}
|
||||
}
|
||||
|
||||
// gracefulShutdown waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func gracefulShutdown(timeout time.Duration) error {
|
||||
root.Finish(ErrProgramExiting)
|
||||
root.finishChildren()
|
||||
root.runCallbacks()
|
||||
if !root.waitFinish(timeout) {
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
9
internal/utils/testing/log.go
Normal file
9
internal/utils/testing/log.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package expect
|
||||
|
||||
import "github.com/rs/zerolog"
|
||||
|
||||
func init() {
|
||||
if isTest {
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
}
|
|
@ -56,7 +56,7 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
|
|||
fwMap: make(map[string]*fileWatcher),
|
||||
eventCh: make(chan Event),
|
||||
errCh: make(chan gperr.Error),
|
||||
task: parent.Subtask("dir_watcher(" + dirPath + ")"),
|
||||
task: parent.Subtask("dir_watcher("+dirPath+")", true),
|
||||
}
|
||||
go helper.start()
|
||||
return helper
|
||||
|
|
Loading…
Add table
Reference in a new issue