mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-01 01:22:34 +02:00
fix(task): refactor task module and fix reload stuck/error, fix some logic
This commit is contained in:
parent
c90795e614
commit
2628d9e8a8
14 changed files with 371 additions and 443 deletions
|
@ -117,13 +117,13 @@ func Reload() gperr.Error {
|
||||||
newCfg := newConfig()
|
newCfg := newConfig()
|
||||||
err := newCfg.load()
|
err := newCfg.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newCfg.task.Finish(err)
|
newCfg.task.FinishAndWait(err)
|
||||||
return gperr.New(ansi.Warning("using last config")).With(err)
|
return gperr.New(ansi.Warning("using last config")).With(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cancel all current subtasks -> wait
|
// cancel all current subtasks -> wait
|
||||||
// -> replace config -> start new subtasks
|
// -> replace config -> start new subtasks
|
||||||
config.GetInstance().(*Config).Task().Finish("config changed")
|
config.GetInstance().(*Config).Task().FinishAndWait("config changed")
|
||||||
newCfg.Start(StartAllServers)
|
newCfg.Start(StartAllServers)
|
||||||
config.SetInstance(newCfg)
|
config.SetInstance(newCfg)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -58,20 +58,16 @@ func initClientCleaner() {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
closeTimedOutClients()
|
closeTimedOutClients()
|
||||||
case <-cleaner.Context().Done():
|
case <-cleaner.Context().Done():
|
||||||
|
clientMapMu.Lock()
|
||||||
|
for _, c := range clientMap {
|
||||||
|
delete(clientMap, c.Key())
|
||||||
|
c.Client.Close()
|
||||||
|
}
|
||||||
|
clientMapMu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
task.OnProgramExit("docker_clients_cleanup", func() {
|
|
||||||
clientMapMu.Lock()
|
|
||||||
defer clientMapMu.Unlock()
|
|
||||||
|
|
||||||
for _, c := range clientMap {
|
|
||||||
delete(clientMap, c.Key())
|
|
||||||
c.Client.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeTimedOutClients() {
|
func closeTimedOutClients() {
|
||||||
|
|
|
@ -92,7 +92,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
||||||
// same address, likely two routes from the same container
|
// same address, likely two routes from the same container
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
w.task.Finish(causeReload)
|
w.task.FinishAndWait(causeReload)
|
||||||
}
|
}
|
||||||
watcherMapMu.RUnlock()
|
watcherMapMu.RUnlock()
|
||||||
|
|
||||||
|
@ -156,14 +156,15 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
||||||
w.task = parent.Subtask("idlewatcher."+r.Name(), true)
|
w.task = parent.Subtask("idlewatcher."+r.Name(), true)
|
||||||
|
|
||||||
watcherMapMu.Lock()
|
watcherMapMu.Lock()
|
||||||
defer watcherMapMu.Unlock()
|
|
||||||
watcherMap[key] = w
|
watcherMap[key] = w
|
||||||
|
watcherMapMu.Unlock()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
cause := w.watchUntilDestroy()
|
cause := w.watchUntilDestroy()
|
||||||
if errors.Is(cause, causeContainerDestroy) || errors.Is(cause, task.ErrProgramExiting) {
|
if errors.Is(cause, causeContainerDestroy) || errors.Is(cause, task.ErrProgramExiting) {
|
||||||
watcherMapMu.Lock()
|
watcherMapMu.Lock()
|
||||||
defer watcherMapMu.Unlock()
|
|
||||||
delete(watcherMap, key)
|
delete(watcherMap, key)
|
||||||
|
watcherMapMu.Unlock()
|
||||||
w.l.Info().Msg("idlewatcher stopped")
|
w.l.Info().Msg("idlewatcher stopped")
|
||||||
} else if !errors.Is(cause, causeReload) {
|
} else if !errors.Is(cause, causeReload) {
|
||||||
gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l)
|
gperr.LogError("idlewatcher stopped unexpectedly", cause, &w.l)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -10,7 +11,7 @@ import (
|
||||||
|
|
||||||
func convertError(err error) error {
|
func convertError(err error) error {
|
||||||
switch {
|
switch {
|
||||||
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
|
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled), errors.Is(err, net.ErrClosed):
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -3,6 +3,8 @@ package server
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -105,7 +107,7 @@ func (s *Server) Start(parent task.Parent) {
|
||||||
}
|
}
|
||||||
Start(subtask, h3, s.acl, &s.l)
|
Start(subtask, h3, s.acl, &s.l)
|
||||||
if s.http != nil {
|
if s.http != nil {
|
||||||
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
|
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
|
||||||
}
|
}
|
||||||
// s.https is not nil (checked above)
|
// s.https is not nil (checked above)
|
||||||
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
|
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
|
||||||
|
@ -115,7 +117,7 @@ func (s *Server) Start(parent task.Parent) {
|
||||||
Start(subtask, s.https, s.acl, &s.l)
|
Start(subtask, s.https, s.acl, &s.l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) {
|
func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, logger *zerolog.Logger) (port int) {
|
||||||
if srv == nil {
|
if srv == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -138,6 +140,7 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
|
||||||
HandleError(logger, err, "failed to listen on port")
|
HandleError(logger, err, "failed to listen on port")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
port = l.Addr().(*net.TCPAddr).Port
|
||||||
if srv.TLSConfig != nil {
|
if srv.TLSConfig != nil {
|
||||||
l = tls.NewListener(l, srv.TLSConfig)
|
l = tls.NewListener(l, srv.TLSConfig)
|
||||||
}
|
}
|
||||||
|
@ -145,32 +148,36 @@ func Start[Server httpServer](parent task.Parent, srv Server, acl *acl.Config, l
|
||||||
l = acl.WrapTCP(l)
|
l = acl.WrapTCP(l)
|
||||||
}
|
}
|
||||||
serveFunc = getServeFunc(l, srv.Serve)
|
serveFunc = getServeFunc(l, srv.Serve)
|
||||||
|
task.OnCancel("stop", func() {
|
||||||
|
stop(srv, l, logger)
|
||||||
|
})
|
||||||
case *http3.Server:
|
case *http3.Server:
|
||||||
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
|
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
HandleError(logger, err, "failed to listen on port")
|
HandleError(logger, err, "failed to listen on port")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
port = l.LocalAddr().(*net.UDPAddr).Port
|
||||||
if acl != nil {
|
if acl != nil {
|
||||||
l = acl.WrapUDP(l)
|
l = acl.WrapUDP(l)
|
||||||
}
|
}
|
||||||
serveFunc = getServeFunc(l, srv.Serve)
|
serveFunc = getServeFunc(l, srv.Serve)
|
||||||
|
task.OnCancel("stop", func() {
|
||||||
|
stop(srv, l, logger)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
task.OnCancel("stop", func() {
|
|
||||||
stop(srv, logger)
|
|
||||||
})
|
|
||||||
logStarted(srv, logger)
|
logStarted(srv, logger)
|
||||||
go func() {
|
go func() {
|
||||||
err := convertError(serveFunc())
|
err := convertError(serveFunc())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
HandleError(logger, err, "failed to serve "+proto+" server")
|
HandleError(logger, err, "failed to serve "+proto+" server")
|
||||||
}
|
}
|
||||||
task.Finish(err)
|
task.Finish(err)
|
||||||
}()
|
}()
|
||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
|
func stop[Server httpServer](srv Server, l io.Closer, logger *zerolog.Logger) {
|
||||||
if srv == nil {
|
if srv == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -180,7 +187,7 @@ func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
|
||||||
ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second)
|
ctx, cancel := context.WithTimeout(task.RootContext(), 1*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := convertError(srv.Shutdown(ctx)); err != nil {
|
if err := convertError(errors.Join(srv.Shutdown(ctx), l.Close())); err != nil {
|
||||||
HandleError(logger, err, "failed to shutdown "+proto+" server")
|
HandleError(logger, err, "failed to shutdown "+proto+" server")
|
||||||
} else {
|
} else {
|
||||||
logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")
|
logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")
|
||||||
|
|
|
@ -100,18 +100,21 @@ func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error {
|
||||||
func (p *Provider) Start(parent task.Parent) gperr.Error {
|
func (p *Provider) Start(parent task.Parent) gperr.Error {
|
||||||
t := parent.Subtask("provider."+p.String(), false)
|
t := parent.Subtask("provider."+p.String(), false)
|
||||||
|
|
||||||
|
routesTask := t.Subtask("routes", false)
|
||||||
errs := gperr.NewBuilder("routes error")
|
errs := gperr.NewBuilder("routes error")
|
||||||
for _, r := range p.routes {
|
for _, r := range p.routes {
|
||||||
errs.Add(p.startRoute(t, r))
|
errs.Add(p.startRoute(routesTask, r))
|
||||||
}
|
}
|
||||||
|
|
||||||
eventQueue := events.NewEventQueue(
|
eventQueue := events.NewEventQueue(
|
||||||
t.Subtask("event_queue", false),
|
t.Subtask("event_queue", false),
|
||||||
providerEventFlushInterval,
|
providerEventFlushInterval,
|
||||||
func(events []events.Event) {
|
func(events []events.Event) {
|
||||||
|
routesTask.FinishAndWait("reload routes")
|
||||||
|
routesTask = t.Subtask("routes", false)
|
||||||
handler := p.newEventHandler()
|
handler := p.newEventHandler()
|
||||||
// routes' lifetime should follow the provider's lifetime
|
// routes' lifetime should follow the provider's lifetime
|
||||||
handler.Handle(t, events)
|
handler.Handle(routesTask, events)
|
||||||
handler.Log()
|
handler.Log()
|
||||||
},
|
},
|
||||||
func(err gperr.Error) {
|
func(err gperr.Error) {
|
||||||
|
|
|
@ -8,45 +8,59 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// debug only.
|
// debug only.
|
||||||
func (t *Task) listStuckedCallbacks() []string {
|
func listStuckedCallbacks(t *Task) []string {
|
||||||
t.mu.Lock()
|
callbacks := make([]string, 0)
|
||||||
defer t.mu.Unlock()
|
if t.onFinish != nil {
|
||||||
callbacks := make([]string, 0, len(t.callbacksOnFinish))
|
for c := range t.onFinish.Range {
|
||||||
for c := range t.callbacksOnFinish {
|
callbacks = append(callbacks, c.about)
|
||||||
callbacks = append(callbacks, c.about)
|
}
|
||||||
|
}
|
||||||
|
if t.onCancel != nil {
|
||||||
|
for c := range t.onCancel.Range {
|
||||||
|
callbacks = append(callbacks, c.about)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t.children != nil {
|
||||||
|
for c := range t.children.Range {
|
||||||
|
callbacks = append(callbacks, listStuckedCallbacks(c)...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return callbacks
|
return callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
// debug only.
|
// debug only.
|
||||||
func (t *Task) listStuckedChildren() []string {
|
func listStuckedChildren(t *Task) []string {
|
||||||
t.mu.Lock()
|
if t.children != nil {
|
||||||
defer t.mu.Unlock()
|
children := make([]string, 0)
|
||||||
children := make([]string, 0, len(t.children))
|
for c := range t.children.Range {
|
||||||
for c := range t.children {
|
children = append(children, c.String())
|
||||||
if c.isFinished() {
|
children = append(children, listStuckedCallbacks(c)...)
|
||||||
continue
|
|
||||||
}
|
|
||||||
children = append(children, c.String())
|
|
||||||
if len(c.children) > 0 {
|
|
||||||
children = append(children, c.listStuckedChildren()...)
|
|
||||||
}
|
}
|
||||||
|
return children
|
||||||
}
|
}
|
||||||
return children
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Task) reportStucked() {
|
func (t *Task) reportStucked() {
|
||||||
callbacks := t.listStuckedCallbacks()
|
callbacks := listStuckedCallbacks(t)
|
||||||
children := t.listStuckedChildren()
|
children := listStuckedChildren(t)
|
||||||
if len(callbacks) == 0 && len(children) == 0 {
|
if len(callbacks) == 0 && len(children) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children)))
|
fmtOutput := gperr.NewBuilder(fmt.Sprintf("%s stucked callbacks: %d, stucked children: %d", t.String(), len(callbacks), len(children)))
|
||||||
if len(callbacks) > 0 {
|
if len(callbacks) > 0 {
|
||||||
fmtOutput.Add(gperr.New("callbacks").With(gperr.Multiline().AddLinesString(callbacks...)))
|
callbackBuilder := gperr.NewBuilder("callbacks")
|
||||||
|
for _, c := range callbacks {
|
||||||
|
callbackBuilder.Adds(c)
|
||||||
|
}
|
||||||
|
fmtOutput.Add(callbackBuilder.Error())
|
||||||
}
|
}
|
||||||
if len(children) > 0 {
|
if len(children) > 0 {
|
||||||
fmtOutput.Add(gperr.New("children").With(gperr.Multiline().AddLinesString(children...)))
|
childrenBuilder := gperr.NewBuilder("children")
|
||||||
|
for _, c := range children {
|
||||||
|
childrenBuilder.Adds(c)
|
||||||
|
}
|
||||||
|
fmtOutput.Add(childrenBuilder.Error())
|
||||||
}
|
}
|
||||||
log.Warn().Msg(fmtOutput.String())
|
log.Warn().Msg(fmtOutput.String())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,237 +0,0 @@
|
||||||
package task
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
taskPool = make(chan *Task, 100)
|
|
||||||
|
|
||||||
voidTask = &Task{ctx: context.Background()}
|
|
||||||
root = newRoot()
|
|
||||||
|
|
||||||
cancelCtx context.Context
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
cancelCtx = ctx //nolint:fatcontext
|
|
||||||
|
|
||||||
voidTask.parent = root
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCleanup() {
|
|
||||||
root = newRoot()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRoot() *Task {
|
|
||||||
return newTask("root", voidTask, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func noCancel(error) {
|
|
||||||
// do nothing
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func newTask(name string, parent *Task, needFinish bool) *Task {
|
|
||||||
var t *Task
|
|
||||||
select {
|
|
||||||
case t = <-taskPool:
|
|
||||||
t.finished.Store(false)
|
|
||||||
default:
|
|
||||||
t = &Task{}
|
|
||||||
}
|
|
||||||
t.name = name
|
|
||||||
t.parent = parent
|
|
||||||
if needFinish {
|
|
||||||
t.ctx, t.cancel = context.WithCancelCause(parent.ctx)
|
|
||||||
} else {
|
|
||||||
t.ctx, t.cancel = parent.ctx, noCancel
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func (t *Task) needFinish() bool {
|
|
||||||
return t.ctx != t.parent.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func (t *Task) isCanceled() bool {
|
|
||||||
return t.cancel == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func putTask(t *Task) {
|
|
||||||
select {
|
|
||||||
case taskPool <- t:
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
|
|
||||||
if !t.needFinish() {
|
|
||||||
if waitSubTasks {
|
|
||||||
t.parent.addCallback(about, func() {
|
|
||||||
if !t.waitFinish(taskTimeout) {
|
|
||||||
t.reportStucked()
|
|
||||||
}
|
|
||||||
fn()
|
|
||||||
}, false)
|
|
||||||
} else {
|
|
||||||
t.parent.addCallback(about, fn, false)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !waitSubTasks {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.callbacksOnCancel == nil {
|
|
||||||
t.callbacksOnCancel = make(callbacksSet)
|
|
||||||
go func() {
|
|
||||||
<-t.ctx.Done()
|
|
||||||
for c := range t.callbacksOnCancel {
|
|
||||||
go func() {
|
|
||||||
invokeWithRecover(c)
|
|
||||||
t.mu.Lock()
|
|
||||||
delete(t.callbacksOnCancel, c)
|
|
||||||
t.mu.Unlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
if t.isCanceled() {
|
|
||||||
log.Panic().
|
|
||||||
Str("task", t.String()).
|
|
||||||
Str("callback", about).
|
|
||||||
Msg("callback added to canceled task")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.callbacksOnFinish == nil {
|
|
||||||
t.callbacksOnFinish = make(callbacksSet)
|
|
||||||
}
|
|
||||||
t.callbacksOnFinish[&Callback{
|
|
||||||
fn: fn,
|
|
||||||
about: about,
|
|
||||||
}] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func (t *Task) addChild(child *Task) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
|
|
||||||
if t.isCanceled() {
|
|
||||||
log.Panic().
|
|
||||||
Str("task", t.String()).
|
|
||||||
Str("child", child.Name()).
|
|
||||||
Msg("child added to canceled task")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.children == nil {
|
|
||||||
t.children = make(childrenSet)
|
|
||||||
}
|
|
||||||
t.children[child] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func (t *Task) removeChild(child *Task) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
delete(t.children, child)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) runOnFinishCallbacks() {
|
|
||||||
if len(t.callbacksOnFinish) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for c := range t.callbacksOnFinish {
|
|
||||||
go func() {
|
|
||||||
invokeWithRecover(c)
|
|
||||||
t.mu.Lock()
|
|
||||||
delete(t.callbacksOnFinish, c)
|
|
||||||
t.mu.Unlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) waitFinish(timeout time.Duration) bool {
|
|
||||||
// return directly if already finished
|
|
||||||
if t.isFinished() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
ok := true
|
|
||||||
if len(children) != 0 {
|
|
||||||
ok = waitEmpty(children, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(callbacksOnCancel) != 0 {
|
|
||||||
ok = ok && waitEmpty(callbacksOnCancel, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(callbacksOnFinish) != 0 {
|
|
||||||
ok = ok && waitEmpty(callbacksOnFinish, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
|
|
||||||
if len(set) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
timer := time.NewTimer(timeout)
|
|
||||||
defer timer.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if len(set) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:inline
|
|
||||||
func fmtCause(cause any) error {
|
|
||||||
switch cause := cause.(type) {
|
|
||||||
case nil:
|
|
||||||
return nil
|
|
||||||
case error:
|
|
||||||
return cause
|
|
||||||
case string:
|
|
||||||
return errors.New(cause)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("%v", cause)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,15 +1,10 @@
|
||||||
// This file has the abstract logic of the task system.
|
|
||||||
//
|
|
||||||
// The implementation of the task system is in the impl.go file.
|
|
||||||
package task
|
package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,21 +30,23 @@ type (
|
||||||
//
|
//
|
||||||
// Use Task.Finish to stop all subtasks of the Task.
|
// Use Task.Finish to stop all subtasks of the Task.
|
||||||
Task struct {
|
Task struct {
|
||||||
name string
|
parent *Task
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelCauseFunc
|
||||||
|
done chan struct{}
|
||||||
|
finishCalled bool
|
||||||
|
onCancel *withWg[*Callback]
|
||||||
|
onFinish *withWg[*Callback]
|
||||||
|
children *withWg[*Task]
|
||||||
|
|
||||||
parent *Task
|
mu sync.Mutex
|
||||||
children childrenSet
|
|
||||||
callbacksOnFinish callbacksSet
|
|
||||||
callbacksOnCancel callbacksSet
|
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelCauseFunc
|
|
||||||
|
|
||||||
finished atomic.Bool
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
Parent interface {
|
Parent interface {
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
|
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||||
|
//
|
||||||
|
// This should not be called after Finish is called on the task or its parent task.
|
||||||
Subtask(name string, needFinish bool) *Task
|
Subtask(name string, needFinish bool) *Task
|
||||||
Name() string
|
Name() string
|
||||||
Finish(reason any)
|
Finish(reason any)
|
||||||
|
@ -57,124 +54,193 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
|
||||||
childrenSet = map[*Task]struct{}
|
|
||||||
callbacksSet = map[*Callback]struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
const taskTimeout = 3 * time.Second
|
const taskTimeout = 3 * time.Second
|
||||||
|
|
||||||
func (t *Task) Context() context.Context {
|
func (t *Task) Context() context.Context {
|
||||||
return t.ctx
|
return t.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinishCause returns the reason / error that caused the task to be finished.
|
|
||||||
func (t *Task) FinishCause() error {
|
|
||||||
return context.Cause(t.ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnFinished calls fn when the task is canceled and all subtasks are finished.
|
|
||||||
//
|
|
||||||
// It should not be called after Finish is called.
|
|
||||||
func (t *Task) OnFinished(about string, fn func()) {
|
|
||||||
t.addCallback(about, fn, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnCancel calls fn when the task is canceled.
|
|
||||||
//
|
|
||||||
// It should not be called after Finish is called.
|
|
||||||
func (t *Task) OnCancel(about string, fn func()) {
|
|
||||||
t.addCallback(about, fn, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finish cancel all subtasks and wait for them to finish,
|
|
||||||
// then marks the task as finished, with the given reason (if any).
|
|
||||||
func (t *Task) Finish(reason any) {
|
|
||||||
t.mu.Lock()
|
|
||||||
if t.isCanceled() {
|
|
||||||
t.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.cancel(fmtCause(reason))
|
|
||||||
t.ctx, t.cancel = cancelCtx, nil
|
|
||||||
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
t.finishAndWait()
|
|
||||||
t.finished.Store(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) finishAndWait() {
|
|
||||||
ok := true
|
|
||||||
|
|
||||||
if !waitEmpty(t.children, taskTimeout) {
|
|
||||||
t.reportStucked()
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
t.runOnFinishCallbacks()
|
|
||||||
|
|
||||||
if !t.waitFinish(taskTimeout) {
|
|
||||||
t.reportStucked()
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
// clear anyway
|
|
||||||
clear(t.children)
|
|
||||||
clear(t.callbacksOnFinish)
|
|
||||||
|
|
||||||
if t != root && t.needFinish() {
|
|
||||||
t.parent.removeChild(t)
|
|
||||||
}
|
|
||||||
logFinished(t)
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
putTask(t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) isFinished() bool {
|
|
||||||
return t.finished.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
|
||||||
//
|
|
||||||
// This should not be called after Finish is called on the task or its parent task.
|
|
||||||
func (t *Task) Subtask(name string, needFinish bool) *Task {
|
|
||||||
panicIfFinished(t, "Subtask is called")
|
|
||||||
|
|
||||||
child := newTask(name, t, needFinish)
|
|
||||||
|
|
||||||
if needFinish {
|
|
||||||
t.addChild(child)
|
|
||||||
}
|
|
||||||
|
|
||||||
logStarted(child)
|
|
||||||
return child
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the name of the task without parent names.
|
|
||||||
func (t *Task) Name() string {
|
func (t *Task) Name() string {
|
||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the full name of the task.
|
// String returns the full name of the task.
|
||||||
func (t *Task) String() string {
|
func (t *Task) String() string {
|
||||||
if t.parent != root {
|
return t.fullName()
|
||||||
return t.parent.String() + "." + t.name
|
|
||||||
}
|
|
||||||
return t.name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalText implements encoding.TextMarshaler.
|
// MarshalText implements encoding.TextMarshaler.
|
||||||
func (t *Task) MarshalText() ([]byte, error) {
|
func (t *Task) MarshalText() ([]byte, error) {
|
||||||
return []byte(t.String()), nil
|
return []byte(t.fullName()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func invokeWithRecover(cb *Callback) {
|
// Finish marks the task as finished, with the given reason (if any).
|
||||||
defer func() {
|
func (t *Task) Finish(reason any) {
|
||||||
if err := recover(); err != nil {
|
t.finish(reason, false)
|
||||||
log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic")
|
}
|
||||||
panicWithDebugStack()
|
|
||||||
}
|
// FinishCause returns the reason / error that caused the task to be finished.
|
||||||
}()
|
func (t *Task) FinishCause() error {
|
||||||
cb.fn()
|
return context.Cause(t.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishAndWait cancel all subtasks and wait for them to finish,
|
||||||
|
// then marks the task as finished, with the given reason (if any).
|
||||||
|
func (t *Task) FinishAndWait(reason any) {
|
||||||
|
t.finish(reason, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnFinished calls fn when the task is canceled and all subtasks are finished.
|
||||||
|
//
|
||||||
|
// It should not be called after Finish is called.
|
||||||
|
func (t *Task) OnFinished(about string, fn func()) {
|
||||||
|
if !t.needFinish() {
|
||||||
|
t.OnCancel(about, fn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.onFinish == nil {
|
||||||
|
t.onFinish = newWithWg[*Callback]()
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-t.ctx.Done()
|
||||||
|
<-t.done
|
||||||
|
for cb := range t.onFinish.Range {
|
||||||
|
go func(cb *Callback) {
|
||||||
|
invokeWithRecover(cb)
|
||||||
|
t.onFinish.Delete(cb)
|
||||||
|
}(cb)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.onFinish.Add(&Callback{fn: fn, about: about})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnCancel calls fn when the task is canceled.
|
||||||
|
//
|
||||||
|
// It should not be called after Finish is called.
|
||||||
|
func (t *Task) OnCancel(about string, fn func()) {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.onCancel == nil {
|
||||||
|
t.onCancel = newWithWg[*Callback]()
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-t.ctx.Done()
|
||||||
|
for cb := range t.onCancel.Range {
|
||||||
|
go func(cb *Callback) {
|
||||||
|
invokeWithRecover(cb)
|
||||||
|
t.onCancel.Delete(cb)
|
||||||
|
}(cb)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.onCancel.Add(&Callback{fn: fn, about: about})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||||
|
//
|
||||||
|
// This should not be called after Finish is called on the task or its parent task.
|
||||||
|
func (t *Task) Subtask(name string, needFinish bool) *Task {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.children == nil {
|
||||||
|
t.children = newWithWg[*Task]()
|
||||||
|
t.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
child := &Task{
|
||||||
|
name: name,
|
||||||
|
parent: t,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.children.Add(child)
|
||||||
|
|
||||||
|
child.ctx, child.cancel = context.WithCancelCause(t.ctx)
|
||||||
|
|
||||||
|
if needFinish {
|
||||||
|
child.done = make(chan struct{})
|
||||||
|
} else {
|
||||||
|
child.done = closedCh
|
||||||
|
go func() {
|
||||||
|
<-child.ctx.Done()
|
||||||
|
child.Finish(t.FinishCause())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
logStarted(child)
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) finish(reason any, wait bool) {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.finishCalled {
|
||||||
|
t.mu.Unlock()
|
||||||
|
// wait but not report stucked (again)
|
||||||
|
t.waitFinish(taskTimeout)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.finishCalled = true
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
if t.needFinish() {
|
||||||
|
close(t.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.cancel(fmtCause(reason))
|
||||||
|
if wait && !t.waitFinish(taskTimeout) {
|
||||||
|
t.reportStucked()
|
||||||
|
}
|
||||||
|
if t != root {
|
||||||
|
t.parent.children.Delete(t)
|
||||||
|
}
|
||||||
|
logFinished(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) waitFinish(timeout time.Duration) bool {
|
||||||
|
if t.children == nil && t.onCancel == nil && t.onFinish == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
if t.children != nil {
|
||||||
|
t.children.Wait()
|
||||||
|
}
|
||||||
|
if t.onCancel != nil {
|
||||||
|
t.onCancel.Wait()
|
||||||
|
}
|
||||||
|
if t.onFinish != nil {
|
||||||
|
t.onFinish.Wait()
|
||||||
|
}
|
||||||
|
<-t.done
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
timeoutCh := time.After(timeout)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return true
|
||||||
|
case <-timeoutCh:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) fullName() string {
|
||||||
|
if t.parent == root {
|
||||||
|
return t.name
|
||||||
|
}
|
||||||
|
return t.parent.fullName() + "." + t.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) needFinish() bool {
|
||||||
|
return t.done != closedCh
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,12 +12,6 @@ func panicWithDebugStack() {
|
||||||
panic(string(debug.Stack()))
|
panic(string(debug.Stack()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func panicIfFinished(t *Task, reason string) {
|
|
||||||
if t.isFinished() {
|
|
||||||
log.Panic().Msg("task " + t.String() + " is finished but " + reason)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func logStarted(t *Task) {
|
func logStarted(t *Task) {
|
||||||
log.Info().Msg("task " + t.String() + " started")
|
log.Info().Msg("task " + t.String() + " started")
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,10 +6,6 @@ func panicWithDebugStack() {
|
||||||
// do nothing
|
// do nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
func panicIfFinished(t *Task, reason string) {
|
|
||||||
// do nothing
|
|
||||||
}
|
|
||||||
|
|
||||||
func logStarted(t *Task) {
|
func logStarted(t *Task) {
|
||||||
// do nothing
|
// do nothing
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testTask() *Task {
|
func testTask() *Task {
|
||||||
|
@ -35,7 +35,7 @@ func TestChildTaskCancellation(t *testing.T) {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-child.Context().Done():
|
case <-child.Context().Done():
|
||||||
ExpectError(t, context.Canceled, child.Context().Err())
|
expect.ErrorIs(t, context.Canceled, child.Context().Err())
|
||||||
default:
|
default:
|
||||||
t.Fatal("subTask context was not canceled as expected")
|
t.Fatal("subTask context was not canceled as expected")
|
||||||
}
|
}
|
||||||
|
@ -80,10 +80,10 @@ func TestTaskOnCancelOnFinished(t *testing.T) {
|
||||||
shouldTrueOnFinish = true
|
shouldTrueOnFinish = true
|
||||||
})
|
})
|
||||||
|
|
||||||
ExpectFalse(t, shouldTrueOnFinish)
|
expect.False(t, shouldTrueOnFinish)
|
||||||
task.Finish(nil)
|
task.Finish(nil)
|
||||||
ExpectTrue(t, shouldTrueOnCancel)
|
expect.True(t, shouldTrueOnCancel)
|
||||||
ExpectTrue(t, shouldTrueOnFinish)
|
expect.True(t, shouldTrueOnFinish)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
|
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
|
||||||
|
@ -108,29 +108,28 @@ func TestCommonFlowWithGracefulShutdown(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ExpectNoError(t, gracefulShutdown(1*time.Second))
|
expect.NoError(t, gracefulShutdown(1*time.Second))
|
||||||
time.Sleep(100 * time.Millisecond)
|
expect.True(t, finished)
|
||||||
ExpectTrue(t, finished)
|
|
||||||
|
|
||||||
ExpectTrue(t, root.waitFinish(1*time.Second))
|
expect.ErrorIs(t, ErrProgramExiting, context.Cause(task.Context()))
|
||||||
ExpectError(t, context.Canceled, context.Cause(task.Context()))
|
expect.ErrorIs(t, context.Canceled, task.Context().Err())
|
||||||
ExpectError(t, ErrProgramExiting, task.Context().Err())
|
expect.ErrorIs(t, ErrProgramExiting, task.FinishCause())
|
||||||
ExpectError(t, ErrProgramExiting, task.FinishCause())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeoutOnGracefulShutdown(t *testing.T) {
|
func TestTimeoutOnGracefulShutdown(t *testing.T) {
|
||||||
t.Cleanup(testCleanup)
|
t.Cleanup(testCleanup)
|
||||||
_ = testTask()
|
_ = testTask()
|
||||||
|
|
||||||
ExpectError(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond))
|
expect.ErrorIs(t, context.DeadlineExceeded, gracefulShutdown(time.Millisecond))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFinishMultipleCalls(t *testing.T) {
|
func TestFinishMultipleCalls(t *testing.T) {
|
||||||
t.Cleanup(testCleanup)
|
t.Cleanup(testCleanup)
|
||||||
task := testTask()
|
task := testTask()
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(5)
|
n := 20
|
||||||
for range 5 {
|
wg.Add(n)
|
||||||
|
for range n {
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
task.Finish(nil)
|
task.Finish(nil)
|
||||||
|
@ -157,8 +156,8 @@ func BenchmarkTasksNeedFinish(b *testing.B) {
|
||||||
|
|
||||||
func BenchmarkContextWithCancel(b *testing.B) {
|
func BenchmarkContextWithCancel(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
task, taskCancel := context.WithCancel(b.Context())
|
task, taskCancel := context.WithCancelCause(b.Context())
|
||||||
taskCancel()
|
taskCancel(nil)
|
||||||
<-task.Done()
|
<-task.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package task
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -13,6 +14,31 @@ import (
|
||||||
|
|
||||||
var ErrProgramExiting = errors.New("program exiting")
|
var ErrProgramExiting = errors.New("program exiting")
|
||||||
|
|
||||||
|
var root *Task
|
||||||
|
|
||||||
|
var closedCh = make(chan struct{})
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
close(closedCh)
|
||||||
|
initRoot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initRoot() {
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
root = &Task{
|
||||||
|
name: "root",
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
done: closedCh,
|
||||||
|
}
|
||||||
|
root.parent = root
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCleanup() {
|
||||||
|
root.cancel(nil)
|
||||||
|
initRoot()
|
||||||
|
}
|
||||||
|
|
||||||
// RootTask returns a new Task with the given name, derived from the root context.
|
// RootTask returns a new Task with the given name, derived from the root context.
|
||||||
//
|
//
|
||||||
//go:inline
|
//go:inline
|
||||||
|
@ -29,7 +55,7 @@ func RootContextCanceled() <-chan struct{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func OnProgramExit(about string, fn func()) {
|
func OnProgramExit(about string, fn func()) {
|
||||||
root.OnFinished(about, fn)
|
root.OnCancel(about, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitExit waits for a signal to shutdown the program, and then waits for all tasks to finish, up to the given timeout.
|
// WaitExit waits for a signal to shutdown the program, and then waits for all tasks to finish, up to the given timeout.
|
||||||
|
@ -59,19 +85,33 @@ func WaitExit(shutdownTimeout int) {
|
||||||
// still running when the timeout was reached, and their current tree
|
// still running when the timeout was reached, and their current tree
|
||||||
// of subtasks.
|
// of subtasks.
|
||||||
func gracefulShutdown(timeout time.Duration) error {
|
func gracefulShutdown(timeout time.Duration) error {
|
||||||
root.mu.Lock()
|
go root.Finish(ErrProgramExiting)
|
||||||
if root.isCanceled() {
|
if !root.waitFinish(timeout) {
|
||||||
cause := context.Cause(root.ctx)
|
|
||||||
root.mu.Unlock()
|
|
||||||
return cause
|
|
||||||
}
|
|
||||||
root.mu.Unlock()
|
|
||||||
|
|
||||||
root.cancel(ErrProgramExiting)
|
|
||||||
ok := waitEmpty(root.children, timeout)
|
|
||||||
root.runOnFinishCallbacks()
|
|
||||||
if !ok || !root.waitFinish(timeout) {
|
|
||||||
return context.DeadlineExceeded
|
return context.DeadlineExceeded
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func invokeWithRecover(cb *Callback) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
log.Err(fmtCause(err)).Str("callback", cb.about).Msg("panic")
|
||||||
|
panicWithDebugStack()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
cb.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:inline
|
||||||
|
func fmtCause(cause any) error {
|
||||||
|
switch cause := cause.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case error:
|
||||||
|
return cause
|
||||||
|
case string:
|
||||||
|
return errors.New(cause)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v", cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
48
internal/task/with.go
Normal file
48
internal/task/with.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
type withWg[T comparable] struct {
|
||||||
|
m *xsync.Map[T, struct{}]
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWithWg[T comparable]() *withWg[T] {
|
||||||
|
return &withWg[T]{
|
||||||
|
m: xsync.NewMap[T, struct{}](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) Add(ele T) {
|
||||||
|
w.wg.Add(1)
|
||||||
|
w.m.Store(ele, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) AddWithoutWG(ele T) {
|
||||||
|
w.m.Store(ele, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) Delete(key T) {
|
||||||
|
w.wg.Done()
|
||||||
|
w.m.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) DeleteWithoutWG(key T) {
|
||||||
|
w.m.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) Wait() {
|
||||||
|
w.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *withWg[T]) Range(yield func(T) bool) {
|
||||||
|
for ele := range w.m.Range {
|
||||||
|
if !yield(ele) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue