simplify task package implementation

This commit is contained in:
yusing 2025-01-01 06:07:32 +08:00
parent e7aaa95ec5
commit 1ab34ed46f
35 changed files with 547 additions and 600 deletions

View file

@ -23,16 +23,16 @@ lint:
enabled: enabled:
- hadolint@2.12.1-beta - hadolint@2.12.1-beta
- actionlint@1.7.4 - actionlint@1.7.4
- checkov@3.2.334 - checkov@3.2.344
- git-diff-check - git-diff-check
- gofmt@1.20.4 - gofmt@1.20.4
- golangci-lint@1.62.2 - golangci-lint@1.62.2
- osv-scanner@1.9.1 - osv-scanner@1.9.2
- oxipng@9.1.3 - oxipng@9.1.3
- prettier@3.4.2 - prettier@3.4.2
- shellcheck@0.10.0 - shellcheck@0.10.0
- shfmt@3.6.0 - shfmt@3.6.0
- trufflehog@3.86.1 - trufflehog@3.88.0
actions: actions:
disabled: disabled:
- trunk-announce - trunk-announce

View file

@ -28,10 +28,10 @@ get:
go get -u ./cmd && go mod tidy go get -u ./cmd && go mod tidy
debug: debug:
GODOXY_DEBUG=1 make run GODOXY_DEBUG=1 BUILD_FLAGS="" make run
debug-trace: debug-trace:
GODOXY_DEBUG=1 GODOXY_TRACE=1 run GODOXY_TRACE=1 make debug
profile: profile:
GODEBUG=gctrace=1 make debug GODEBUG=gctrace=1 make debug

View file

@ -159,8 +159,7 @@ func main() {
// grafully shutdown // grafully shutdown
logging.Info().Msg("shutting down") logging.Info().Msg("shutting down")
task.CancelGlobalContext() _ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown))
_ = task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
} }
func prepareDirectory(dir string) { func prepareDirectory(dir string) {

View file

@ -52,7 +52,7 @@ func List(w http.ResponseWriter, r *http.Request) {
case ListHomepageConfig: case ListHomepageConfig:
U.RespondJSON(w, r, config.HomepageConfig()) U.RespondJSON(w, r, config.HomepageConfig())
case ListTasks: case ListTasks:
U.RespondJSON(w, r, task.DebugTaskMap()) U.RespondJSON(w, r, task.DebugTaskList())
default: default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
} }

View file

@ -153,8 +153,8 @@ func (p *Provider) ScheduleRenewal() {
return return
} }
go func() { go func() {
task := task.GlobalTask("cert renew scheduler") task := task.RootTask("cert-renew-scheduler", true)
defer task.Finish("cert renew scheduler stopped") defer task.Finish(nil)
for { for {
renewalTime := p.ShouldRenewOn() renewalTime := p.ShouldRenewOn()

View file

@ -53,7 +53,7 @@ func newConfig() *Config {
return &Config{ return &Config{
value: types.DefaultConfig(), value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](), providers: F.NewMapOf[string, *proxy.Provider](),
task: task.GlobalTask("config"), task: task.RootTask("config", false),
} }
} }
@ -76,21 +76,19 @@ func MatchDomains() []string {
} }
func WatchChanges() { func WatchChanges() {
task := task.GlobalTask("config watcher") t := task.RootTask("config_watcher", true)
eventQueue := events.NewEventQueue( eventQueue := events.NewEventQueue(
task, t,
configEventFlushInterval, configEventFlushInterval,
OnConfigChange, OnConfigChange,
func(err E.Error) { func(err E.Error) {
E.LogError("config reload error", err, &logger) E.LogError("config reload error", err, &logger)
}, },
) )
eventQueue.Start(cfgWatcher.Events(task.Context())) eventQueue.Start(cfgWatcher.Events(t.Context()))
} }
func OnConfigChange(flushTask *task.Task, ev []events.Event) { func OnConfigChange(ev []events.Event) {
defer flushTask.Finish("config reload complete")
// no matter how many events during the interval // no matter how many events during the interval
// just reload once and check the last event // just reload once and check the last event
switch ev[len(ev)-1].Action { switch ev[len(ev)-1].Action {
@ -116,14 +114,14 @@ func Reload() E.Error {
newCfg := newConfig() newCfg := newConfig()
err := newCfg.load() err := newCfg.load()
if err != nil { if err != nil {
newCfg.task.Finish(err)
return err return err
} }
// cancel all current subtasks -> wait // cancel all current subtasks -> wait
// -> replace config -> start new subtasks // -> replace config -> start new subtasks
instance.task.Finish("config changed") instance.task.Finish("config changed")
instance.task.Wait() instance = newCfg
*instance = *newCfg
instance.StartProxyProviders() instance.StartProxyProviders()
return nil return nil
} }
@ -143,8 +141,7 @@ func (cfg *Config) Task() *task.Task {
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
errs := cfg.providers.CollectErrorsParallel( errs := cfg.providers.CollectErrorsParallel(
func(_ string, p *proxy.Provider) error { func(_ string, p *proxy.Provider) error {
subtask := cfg.task.Subtask(p.String()) return p.Start(cfg.task)
return p.Start(subtask)
}) })
if err := E.Join(errs...); err != nil { if err := E.Join(errs...); err != nil {
@ -209,9 +206,6 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
} }
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
subtask := cfg.task.Subtask("load route providers")
defer subtask.Finish("done")
errs := E.NewBuilder("route provider errors") errs := E.NewBuilder("route provider errors")
results := E.NewBuilder("loaded route providers") results := E.NewBuilder("loaded route providers")

View file

@ -38,7 +38,7 @@ var (
) )
func init() { func init() {
task.GlobalTask("close docker clients").OnFinished("", func() { task.OnProgramExit("docker_clients_cleanup", func() {
clientMap.RangeAllParallel(func(_ string, c Client) { clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() { if c.Connected() {
c.Client.Close() c.Client.Close()

View file

@ -38,7 +38,7 @@ const (
// TODO: support stream // TODO: support stream
func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.RawEntry().HealthCheck hcCfg := entry.RawEntry().HealthCheck
hcCfg.Timeout = idleWakerCheckTimeout hcCfg.Timeout = idleWakerCheckTimeout
@ -46,8 +46,8 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
rp: rp, rp: rp,
stream: stream, stream: stream,
} }
task := parent.Subtask("idlewatcher")
watcher, err := registerWatcher(providerSubTask, entry, waker) watcher, err := registerWatcher(task, entry, waker)
if err != nil { if err != nil {
return nil, E.Errorf("register watcher: %w", err) return nil, E.Errorf("register watcher: %w", err)
} }
@ -63,7 +63,7 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
if common.PrometheusEnabled { if common.PrometheusEnabled {
m := metrics.GetServiceMetrics() m := metrics.GetServiceMetrics()
fqn := providerSubTask.Parent().Name() + "/" + entry.TargetName() fqn := parent.Name() + "/" + entry.TargetName()
waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn)) waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn))
waker.metric.Set(float64(watcher.Status())) waker.metric.Set(float64(watcher.Status()))
} }
@ -71,19 +71,18 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
} }
// lifetime should follow route provider. // lifetime should follow route provider.
func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil) return newWaker(parent, entry, rp, nil)
} }
func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) { func NewStreamWaker(parent task.Parent, entry route.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(providerSubTask, entry, nil, stream) return newWaker(parent, entry, nil, stream)
} }
// Start implements health.HealthMonitor. // Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask *task.Task) E.Error { func (w *Watcher) Start(parent task.Parent) E.Error {
routeSubTask.Finish("ignored") w.task.OnCancel("route_cleanup", func() {
w.task.OnCancel("stop route and cleanup", func() { parent.Finish(w.task.FinishCause())
routeSubTask.Parent().Finish(w.task.FinishCause())
if w.metric != nil { if w.metric != nil {
w.metric.Reset() w.metric.Reset()
} }
@ -91,6 +90,11 @@ func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
return nil return nil
} }
// Task implements health.HealthMonitor.
func (w *Watcher) Task() *task.Task {
return w.task
}
// Finish implements health.HealthMonitor. // Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) { func (w *Watcher) Finish(reason any) {
if w.stream != nil { if w.stream != nil {

View file

@ -51,7 +51,7 @@ var (
const dockerReqTimeout = 3 * time.Second const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) { func registerWatcher(watcherTask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
cfg := entry.IdlewatcherConfig() cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 { if cfg.IdleTimeout == 0 {
@ -67,7 +67,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
w.Config = cfg w.Config = cfg
w.waker = waker w.waker = waker
w.resetIdleTimer() w.resetIdleTimer()
providerSubtask.Finish("used existing watcher") watcherTask.Finish("used existing watcher")
return w, nil return w, nil
} }
@ -81,7 +81,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
Config: cfg, Config: cfg,
waker: waker, waker: waker,
client: client, client: client,
task: providerSubtask, task: watcherTask,
ticker: time.NewTicker(cfg.IdleTimeout), ticker: time.NewTicker(cfg.IdleTimeout),
} }
w.stopByMethod = w.getStopCallback() w.stopByMethod = w.getStopCallback()
@ -210,8 +210,7 @@ func (w *Watcher) resetIdleTimer() {
} }
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("docker event watcher") eventCh, errCh = dockerWatcher.EventsWithOptions(w.Task().Context(), watcher.DockerListOptions{
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
Filters: watcher.NewDockerFilter( Filters: watcher.NewDockerFilter(
watcher.DockerFilterContainer, watcher.DockerFilterContainer,
watcher.DockerFilterContainerNameID(w.ContainerID), watcher.DockerFilterContainerNameID(w.ContainerID),

View file

@ -54,7 +54,7 @@ func SetMiddlewares(mws []map[string]any) error {
return nil return nil
} }
func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) { func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
epAccessLoggerMu.Lock() epAccessLoggerMu.Lock()
defer epAccessLoggerMu.Unlock() defer epAccessLoggerMu.Unlock()

View file

@ -50,10 +50,12 @@ func Join(errors ...error) Error {
if n == 0 { if n == 0 {
return nil return nil
} }
errs := make([]error, 0, n) errs := make([]error, n)
i := 0
for _, err := range errors { for _, err := range errors {
if err != nil { if err != nil {
errs = append(errs, err) errs[i] = err
i++
} }
} }
return &nestedError{Extras: errs} return &nestedError{Extras: errs}

View file

@ -22,6 +22,7 @@ type (
Path string `json:"path" validate:"required"` Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"` Filters Filters `json:"filters"`
Fields Fields `json:"fields"` Fields Fields `json:"fields"`
// Retention *Retention
} }
) )
@ -31,7 +32,7 @@ var (
FormatJSON Format = "json" FormatJSON Format = "json"
) )
const DefaultBufferSize = 100 const DefaultBufferSize = 64 * 1024 // 64KB
func DefaultConfig() *Config { func DefaultConfig() *Config {
return &Config{ return &Config{

View file

@ -0,0 +1,22 @@
package accesslog
import (
"fmt"
"os"
"sync"
"github.com/yusing/go-proxy/internal/task"
)
type File struct {
*os.File
sync.Mutex
}
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
return NewAccessLogger(parent, &File{File: f}, cfg), nil
}

View file

@ -7,18 +7,17 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"time"
) )
type ( type (
CommonFormatter struct { CommonFormatter struct {
cfg *Fields cfg *Fields
} GetTimeNow func() time.Time // for testing purposes only
CombinedFormatter struct {
CommonFormatter
}
JSONFormatter struct {
CommonFormatter
} }
CombinedFormatter CommonFormatter
JSONFormatter CommonFormatter
JSONLogEntry struct { JSONLogEntry struct {
Time string `json:"time"` Time string `json:"time"`
IP string `json:"ip"` IP string `json:"ip"`
@ -39,6 +38,8 @@ type (
} }
) )
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
func scheme(req *http.Request) string { func scheme(req *http.Request) string {
if req.TLS != nil { if req.TLS != nil {
return "https" return "https"
@ -62,7 +63,7 @@ func clientIP(req *http.Request) string {
return req.RemoteAddr return req.RemoteAddr
} }
func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query()) query := f.cfg.Query.ProcessQuery(req.URL.Query())
line.WriteString(req.Host) line.WriteString(req.Host)
@ -71,7 +72,7 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
line.WriteString(clientIP(req)) line.WriteString(clientIP(req))
line.WriteString(" - - [") line.WriteString(" - - [")
line.WriteString(timeNow()) line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
line.WriteString("] \"") line.WriteString("] \"")
line.WriteString(req.Method) line.WriteString(req.Method)
@ -86,8 +87,8 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
line.WriteString(strconv.FormatInt(res.ContentLength, 10)) line.WriteString(strconv.FormatInt(res.ContentLength, 10))
} }
func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
f.CommonFormatter.Format(line, req, res) (*CommonFormatter)(f).Format(line, req, res)
line.WriteString(" \"") line.WriteString(" \"")
line.WriteString(req.Referer()) line.WriteString(req.Referer())
line.WriteString("\" \"") line.WriteString("\" \"")
@ -95,14 +96,14 @@ func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *ht
line.WriteRune('"') line.WriteRune('"')
} }
func (f JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) { func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query()) query := f.cfg.Query.ProcessQuery(req.URL.Query())
headers := f.cfg.Headers.ProcessHeaders(req.Header) headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie") headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies()) cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
entry := JSONLogEntry{ entry := JSONLogEntry{
Time: timeNow(), Time: f.GetTimeNow().Format(LogTimeFormat),
IP: clientIP(req), IP: clientIP(req),
Method: req.Method, Method: req.Method,
Scheme: scheme(req), Scheme: scheme(req),

View file

@ -9,6 +9,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor" "github.com/yusing/go-proxy/internal/watcher/health/monitor"
@ -52,10 +53,13 @@ func New(cfg *Config) *LoadBalancer {
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error { func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
lb.startTime = time.Now() lb.startTime = time.Now()
lb.task = routeSubtask lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
lb.task.OnFinished("loadbalancer cleanup", func() { parent.OnCancel("lb_remove_route", func() {
routes.DeleteHTTPRoute(lb.Link)
})
lb.task.OnFinished("cleanup", func() {
if lb.impl != nil { if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) { lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v) lb.impl.OnRemoveServer(v)
@ -66,6 +70,11 @@ func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
return nil return nil
} }
// Task implements task.TaskStarter.
func (lb *LoadBalancer) Task() *task.Task {
return lb.task
}
// Finish implements task.TaskFinisher. // Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) { func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason) lb.task.Finish(reason)

View file

@ -32,10 +32,10 @@ func setup() {
return return
} }
task := task.GlobalTask("error page") t := task.RootTask("error_page", true)
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath) dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
loadContent() loadContent()
go watchDir(task) go watchDir()
} }
func GetStaticFile(filename string) ([]byte, bool) { func GetStaticFile(filename string) ([]byte, bool) {
@ -73,11 +73,11 @@ func loadContent() {
} }
} }
func watchDir(task *task.Task) { func watchDir() {
eventCh, errCh := dirWatcher.Events(task.Context()) eventCh, errCh := dirWatcher.Events(task.RootContext())
for { for {
select { select {
case <-task.Context().Done(): case <-task.RootContextCanceled():
return return
case event, ok := <-eventCh: case event, ok := <-eventCh:
if !ok { if !ok {

View file

@ -24,8 +24,6 @@ type Server struct {
httpsStarted bool httpsStarted bool
startTime time.Time startTime time.Time
task *task.Task
l zerolog.Logger l zerolog.Logger
} }
@ -76,7 +74,6 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider, CertProvider: opt.CertProvider,
http: httpSer, http: httpSer,
https: httpsSer, https: httpsSer,
task: task.GlobalTask(opt.Name + " server"),
l: logger, l: logger,
} }
} }
@ -108,7 +105,7 @@ func (s *Server) Start() {
s.l.Info().Str("addr", s.https.Addr).Msgf("server started") s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
} }
s.task.OnFinished("stop server", s.stop) task.OnProgramExit("server."+s.Name+".stop", s.stop)
} }
func (s *Server) stop() { func (s *Server) stop() {
@ -117,12 +114,12 @@ func (s *Server) stop() {
} }
if s.http != nil && s.httpStarted { if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(s.task.Context())) s.handleErr("http", s.http.Shutdown(task.RootContext()))
s.httpStarted = false s.httpStarted = false
} }
if s.https != nil && s.httpsStarted { if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(s.task.Context())) s.handleErr("https", s.https.Shutdown(task.RootContext()))
s.httpsStarted = false s.httpsStarted = false
} }
} }

View file

@ -35,7 +35,7 @@ var (
const dispatchErr = "notification dispatch error" const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent *task.Task) *Dispatcher { func StartNotifDispatcher(parent task.Parent) *Dispatcher {
dispatcher = &Dispatcher{ dispatcher = &Dispatcher{
task: parent.Subtask("notification"), task: parent.Subtask("notification"),
logCh: make(chan *LogMessage), logCh: make(chan *LogMessage),

View file

@ -73,19 +73,17 @@ func (r *HTTPRoute) String() string {
return r.TargetName() return r.TargetName()
} }
// Start implements*task.TaskStarter. // Start implements task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { func (r *HTTPRoute) Start(parent task.Parent) E.Error {
if entry.ShouldNotServe(r) { if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil return nil
} }
r.task = providerSubtask r.task = parent.Subtask("http."+r.TargetName(), false)
switch { switch {
case entry.UseIdleWatcher(r): case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName()) waker, err := idlewatcher.NewHTTPWaker(r.task, r.ReverseProxyEntry, r.rp)
waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
if err != nil { if err != nil {
r.task.Finish(err) r.task.Finish(err)
return err return err
@ -98,7 +96,7 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
if err == nil { if err == nil {
fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck) fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck)
r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback) r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
r.task.OnCancel("close docker client", client.Close) r.task.OnCancel("close_docker_client", client.Close)
} }
} }
if r.HealthMon == nil { if r.HealthMon == nil {
@ -137,29 +135,32 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
} }
if r.HealthMon != nil { if r.HealthMon != nil {
healthMonTask := r.task.Subtask("health monitor") if err := r.HealthMon.Start(r.task); err != nil {
if err := r.HealthMon.Start(healthMonTask); err != nil {
E.LogWarn("health monitor error", err, &r.l) E.LogWarn("health monitor error", err, &r.l)
healthMonTask.Finish(err)
} }
} }
if entry.UseLoadBalance(r) { if entry.UseLoadBalance(r) {
r.addToLoadBalancer() r.addToLoadBalancer(parent)
} else { } else {
routes.SetHTTPRoute(r.TargetName(), r) routes.SetHTTPRoute(r.TargetName(), r)
r.task.OnFinished("remove from route table", func() { r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteHTTPRoute(r.TargetName()) routes.DeleteHTTPRoute(r.TargetName())
}) })
} }
if common.PrometheusEnabled { if common.PrometheusEnabled {
r.task.OnFinished("unreg metrics", r.rp.UnregisterMetrics) r.task.OnFinished("metrics_cleanup", r.rp.UnregisterMetrics)
} }
return nil return nil
} }
// Finish implements*task.TaskFinisher. // Task implements task.TaskStarter.
func (r *HTTPRoute) Task() *task.Task {
return r.task
}
// Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason any) { func (r *HTTPRoute) Finish(reason any) {
r.task.Finish(reason) r.task.Finish(reason)
} }
@ -168,7 +169,7 @@ func (r *HTTPRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(w, req) r.handler.ServeHTTP(w, req)
} }
func (r *HTTPRoute) addToLoadBalancer() { func (r *HTTPRoute) addToLoadBalancer(parent task.Parent) {
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
cfg := r.Raw.LoadBalance cfg := r.Raw.LoadBalance
l, ok := routes.GetHTTPRoute(cfg.Link) l, ok := routes.GetHTTPRoute(cfg.Link)
@ -182,11 +183,7 @@ func (r *HTTPRoute) addToLoadBalancer() {
} }
} else { } else {
lb = loadbalancer.New(cfg) lb = loadbalancer.New(cfg)
lbTask := r.task.Parent().Subtask("loadbalancer " + cfg.Link) if err := lb.Start(parent); err != nil {
lbTask.OnCancel("remove lb from routes", func() {
routes.DeleteHTTPRoute(cfg.Link)
})
if err := lb.Start(lbTask); err != nil {
panic(err) // should always return nil panic(err) // should always return nil
} }
linked = &HTTPRoute{ linked = &HTTPRoute{
@ -203,9 +200,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
routes.SetHTTPRoute(cfg.Link, linked) routes.SetHTTPRoute(cfg.Link, linked)
} }
r.loadBalancer = lb r.loadBalancer = lb
r.server = loadbalance.NewServer(r.task.String(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon) r.server = loadbalance.NewServer(r.task.Name(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server) lb.AddServer(r.server)
r.task.OnCancel("remove server from lb", func() { r.task.OnCancel("lb_remove_server", func() {
lb.RemoveServer(r.server) lb.RemoveServer(r.server)
}) })
} }

View file

@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler {
} }
} }
func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) { func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) {
oldRoutes := handler.provider.routes oldRoutes := handler.provider.routes
newRoutes, err := handler.provider.loadRoutesImpl() newRoutes, err := handler.provider.loadRoutesImpl()
if err != nil { if err != nil {
@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
return false return false
} }
func (handler *EventHandler) Add(parent *task.Task, route *route.Route) { func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
err := handler.provider.startRoute(parent, route) err := handler.provider.startRoute(parent, route)
if err != nil { if err != nil {
handler.errs.Add(err.Subject("add")) handler.errs.Add(err.Subject("add"))
@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) {
handler.removed.Adds(route.Entry.Alias) handler.removed.Adds(route.Entry.Alias)
} }
func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) { func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) {
oldRoute.Finish("route update") oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute) err := handler.provider.startRoute(parent, newRoute)
if err != nil { if err != nil {

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"path" "path"
"strings"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -60,7 +61,7 @@ func NewFileProvider(filename string) (p *Provider, err error) {
if name == "" { if name == "" {
return nil, ErrEmptyProviderName return nil, ErrEmptyProviderName
} }
p = newProvider(name, ProviderTypeFile) p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename) p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -100,46 +101,43 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil return []byte(p.String()), nil
} }
func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error { func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error {
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias) err := r.Start(parent)
err := r.Start(subtask)
if err != nil { if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err) // just to ensure
return err.Subject(r.Entry.Alias) return err.Subject(r.Entry.Alias)
} }
p.routes.Store(r.Entry.Alias, r) p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() { r.Task().OnFinished("provider_remove_route", func() {
p.routes.Delete(r.Entry.Alias) p.routes.Delete(r.Entry.Alias)
}) })
return nil return nil
} }
// Start implements*task.TaskStarter. // Start implements*task.TaskStarter.
func (p *Provider) Start(configSubtask *task.Task) E.Error { func (p *Provider) Start(parent task.Parent) E.Error {
// routes and event queue will stop on parent cancel t := parent.Subtask("provider."+p.name, false)
providerTask := configSubtask
// routes and event queue will stop on config reload
errs := p.routes.CollectErrorsParallel( errs := p.routes.CollectErrorsParallel(
func(alias string, r *R.Route) error { func(alias string, r *R.Route) error {
return p.startRoute(providerTask, r) return p.startRoute(t, r)
}) })
eventQueue := events.NewEventQueue( eventQueue := events.NewEventQueue(
providerTask, t.Subtask("event_queue", false),
providerEventFlushInterval, providerEventFlushInterval,
func(flushTask *task.Task, events []events.Event) { func(events []events.Event) {
handler := p.newEventHandler() handler := p.newEventHandler()
// routes' lifetime should follow the provider's lifetime // routes' lifetime should follow the provider's lifetime
handler.Handle(providerTask, events) handler.Handle(t, events)
handler.Log() handler.Log()
flushTask.Finish("events flushed")
}, },
func(err E.Error) { func(err E.Error) {
E.LogError("event error", err, p.Logger()) E.LogError("event error", err, p.Logger())
}, },
) )
eventQueue.Start(p.watcher.Events(providerTask.Context())) eventQueue.Start(p.watcher.Events(t.Context()))
if err := E.Join(errs...); err != nil { if err := E.Join(errs...); err != nil {
return err.Subject(p.String()) return err.Subject(p.String())

View file

@ -47,20 +47,22 @@ func (r *StreamRoute) String() string {
return "stream " + r.TargetName() return "stream " + r.TargetName()
} }
// Start implements*task.TaskStarter. // Start implements task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { func (r *StreamRoute) Start(parent task.Parent) E.Error {
if entry.ShouldNotServe(r) { if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil return nil
} }
r.task = providerSubtask r.task = parent.Subtask("stream." + r.TargetName())
r.Stream = NewStream(r) r.Stream = NewStream(r)
parent.OnCancel("finish", func() {
r.task.Finish(nil)
})
switch { switch {
case entry.UseIdleWatcher(r): case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName()) waker, err := idlewatcher.NewStreamWaker(r.task, r.StreamEntry, r.Stream)
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil { if err != nil {
r.task.Finish(err) r.task.Finish(err)
return err return err
@ -73,7 +75,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
if err == nil { if err == nil {
fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck) fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck)
r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback) r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
r.task.OnCancel("close docker client", client.Close) r.task.OnCancel("close_docker_client", client.Close)
} }
} }
if r.HealthMon == nil { if r.HealthMon == nil {
@ -86,7 +88,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
return E.From(err) return E.From(err)
} }
r.task.OnFinished("close stream", func() { r.task.OnFinished("close_stream", func() {
if err := r.Stream.Close(); err != nil { if err := r.Stream.Close(); err != nil {
E.LogError("close stream failed", err, &r.l) E.LogError("close stream failed", err, &r.l)
} }
@ -97,22 +99,26 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
Msg("listening") Msg("listening")
if r.HealthMon != nil { if r.HealthMon != nil {
healthMonTask := r.task.Subtask("health monitor") if err := r.HealthMon.Start(r.task); err != nil {
if err := r.HealthMon.Start(healthMonTask); err != nil {
E.LogWarn("health monitor error", err, &r.l) E.LogWarn("health monitor error", err, &r.l)
healthMonTask.Finish(err)
} }
} }
go r.acceptConnections() go r.acceptConnections()
routes.SetStreamRoute(r.TargetName(), r) routes.SetStreamRoute(r.TargetName(), r)
r.task.OnFinished("remove from route table", func() { r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteStreamRoute(r.TargetName()) routes.DeleteStreamRoute(r.TargetName())
}) })
return nil return nil
} }
// Task implements task.TaskStarter.
func (r *StreamRoute) Task() *task.Task {
return r.task
}
// Finish implements task.TaskFinisher.
func (r *StreamRoute) Finish(reason any) { func (r *StreamRoute) Finish(reason any) {
r.task.Finish(reason) r.task.Finish(reason)
} }

View file

@ -95,7 +95,7 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
return fmt.Errorf("unexpected listener type: %T", stream) return fmt.Errorf("unexpected listener type: %T", stream)
} }
case io.ReadWriteCloser: case io.ReadWriteCloser:
stream.task.OnCancel("close conn", func() { conn.Close() }) stream.task.OnCancel("close_conn", func() { conn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout} dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())

View file

@ -4,355 +4,194 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strings" "runtime/debug"
"sync" "sync"
"time"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils/strutils"
F "github.com/yusing/go-proxy/internal/utils/functional"
) )
var globalTask = createGlobalTask()
func createGlobalTask() (t *Task) {
t = new(Task)
t.name = "root"
t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.subtasks = F.NewSet[*Task]()
return
}
func testResetGlobalTask() {
globalTask = createGlobalTask()
}
type ( type (
TaskStarter interface { TaskStarter interface {
// Start starts the object that implements TaskStarter, // Start starts the object that implements TaskStarter,
// and returns an error if it fails to start. // and returns an error if it fails to start.
// //
// The task passed must be a subtask of the caller task.
//
// callerSubtask.Finish must be called when start fails or the object is finished. // callerSubtask.Finish must be called when start fails or the object is finished.
Start(callerSubtask *Task) E.Error Start(parent Parent) E.Error
Task() *Task
} }
TaskFinisher interface { TaskFinisher interface {
// Finish marks the task as finished and cancel its context.
//
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
//
// Note that it will also cancel all subtasks.
Finish(reason any) Finish(reason any)
} }
// Task controls objects' lifetime. // Task controls objects' lifetime.
// //
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
// //
// When passing a Task object to another function,
// it must be a sub-Task of the current Task,
// in name of "`currentTaskName`Subtask"
//
// Use Task.Finish to stop all subtasks of the Task. // Use Task.Finish to stop all subtasks of the Task.
Task struct { Task struct {
name string
children sync.WaitGroup
onFinished sync.WaitGroup
finished chan struct{}
ctx context.Context ctx context.Context
cancel context.CancelCauseFunc cancel context.CancelCauseFunc
parent *Task once sync.Once
subtasks F.Set[*Task] }
subTasksWg sync.WaitGroup Parent interface {
Context() context.Context
name string Subtask(name string, needFinish ...bool) *Task
Name() string
OnFinishedFuncs []func() Finish(reason any)
OnFinishedMu sync.Mutex OnCancel(name string, f func())
onFinishedWg sync.WaitGroup
finishOnce sync.Once
} }
) )
var (
ErrProgramExiting = errors.New("program exiting")
ErrTaskCanceled = errors.New("task canceled")
logger = logging.With().Str("module", "task").Logger()
)
// GlobalTask returns a new Task with the given name, derived from the global context.
func GlobalTask(format string, args ...any) *Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return globalTask.Subtask(format)
}
// DebugTaskMap returns a map[string]any representation of the global task tree.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func DebugTaskMap() map[string]any {
return globalTask.serialize()
}
// CancelGlobalContext cancels the global task context, which will cause all tasks
// created to be canceled. This should be called before exiting the program
// to ensure that all tasks are properly cleaned up.
func CancelGlobalContext() {
globalTask.cancel(ErrProgramExiting)
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) (err error) {
done := make(chan struct{})
after := time.After(timeout)
go func() {
globalTask.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
return context.DeadlineExceeded
}
}
}
func (t *Task) trace(msg string) {
logger.Trace().Str("name", t.name).Msg(msg)
}
// Name returns the name of the task.
func (t *Task) Name() string {
if !common.IsTrace {
return t.name
}
parts := strings.Split(t.name, " > ")
return parts[len(parts)-1]
}
// String returns the name of the task.
func (t *Task) String() string {
return t.name
}
// Context returns the context associated with the task. This context is
// canceled when Finish of the task is called, or parent task is canceled.
func (t *Task) Context() context.Context { func (t *Task) Context() context.Context {
return t.ctx return t.ctx
} }
func (t *Task) Finished() <-chan struct{} {
return t.finished
}
// FinishCause returns the reason / error that caused the task to be finished. // FinishCause returns the reason / error that caused the task to be finished.
func (t *Task) FinishCause() error { func (t *Task) FinishCause() error {
cause := context.Cause(t.ctx) return context.Cause(t.ctx)
if cause == nil {
return t.ctx.Err()
}
return cause
} }
// Parent returns the parent task of the current task. // OnFinished calls fn when the task is canceled and all subtasks are finished.
func (t *Task) Parent() *Task {
return t.parent
}
func (t *Task) runAllOnFinished(onCompTask *Task) {
<-t.ctx.Done()
t.WaitSubTasks()
for _, OnFinishedFunc := range t.OnFinishedFuncs {
OnFinishedFunc()
t.onFinishedWg.Done()
}
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
}
// OnFinished calls fn when all subtasks are finished.
// //
// It cannot be called after Finish or Wait is called. // It should not be called after Finish is called.
func (t *Task) OnFinished(about string, fn func()) { func (t *Task) OnFinished(about string, fn func()) {
if t.parent == globalTask { t.onCancel(about, fn, true)
t.OnCancel(about, fn)
return
}
t.onFinishedWg.Add(1)
t.OnFinishedMu.Lock()
defer t.OnFinishedMu.Unlock()
if t.OnFinishedFuncs == nil {
onCompTask := GlobalTask(t.name + " > OnFinished > " + about)
go t.runAllOnFinished(onCompTask)
}
idx := len(t.OnFinishedFuncs)
wrapped := func() {
defer func() {
if err := recover(); err != nil {
logger.Error().
Str("name", t.name).
Interface("err", err).
Msg("panic in " + about)
}
}()
fn()
logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about)
}
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
} }
// OnCancel calls fn when the task is canceled. // OnCancel calls fn when the task is canceled.
// //
// It cannot be called after Finish or Wait is called. // It should not be called after Finish is called.
func (t *Task) OnCancel(about string, fn func()) { func (t *Task) OnCancel(about string, fn func()) {
onCompTask := GlobalTask(t.name + " > OnFinished") t.onCancel(about, fn, false)
}
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
t.onFinished.Add(1)
go func() { go func() {
<-t.ctx.Done() <-t.ctx.Done()
fn() if waitSubTasks {
onCompTask.Finish("done") t.children.Wait()
t.trace("onCancel done: " + about) }
t.invokeWithRecover(fn, about)
t.onFinished.Done()
}() }()
} }
// Finish marks the task as finished and cancel its context. // Finish cancel all subtasks and wait for them to finish,
// // then marks the task as finished, with the given reason (if any).
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
//
// Note that it will also cancel all subtasks.
func (t *Task) Finish(reason any) { func (t *Task) Finish(reason any) {
var format string select {
switch reason.(type) { case <-t.finished:
case error: return
format = "%w"
case string, fmt.Stringer:
format = "%s"
default: default:
format = "%v" t.once.Do(func() {
t.finish(reason)
})
}
}
func (t *Task) finish(reason any) {
t.cancel(fmtCause(reason))
t.children.Wait()
t.onFinished.Wait()
if t.finished != nil {
close(t.finished)
}
logger.Trace().Msg("task " + t.name + " finished")
}
func fmtCause(cause any) error {
switch cause := cause.(type) {
case nil:
return nil
case error:
return cause
case string:
return errors.New(cause)
default:
return fmt.Errorf("%v", cause)
} }
t.finishOnce.Do(func() {
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
})
t.Wait()
} }
// Subtask returns a new subtask with the given name, derived from the parent's context. // Subtask returns a new subtask with the given name, derived from the parent's context.
// //
// If the parent's context is already canceled, the returned subtask will be canceled immediately. // This should not be called after Finish is called.
// func (t *Task) Subtask(name string, needFinish ...bool) *Task {
// This should not be called after Finish, Wait, or WaitSubTasks is called. nf := len(needFinish) == 0 || needFinish[0]
func (t *Task) Subtask(name string) *Task {
ctx, cancel := context.WithCancelCause(t.ctx)
return t.newSubTask(ctx, cancel, name)
}
func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task { ctx, cancel := context.WithCancelCause(t.ctx)
parent := t child := &Task{
if common.IsTrace { finished: make(chan struct{}, 1),
name = parent.name + " > " + name
}
subtask := &Task{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
name: name,
parent: parent,
subtasks: F.NewSet[*Task](),
} }
parent.subTasksWg.Add(1) if t != root {
parent.subtasks.Add(subtask) child.name = t.name + "." + name
if common.IsTrace { allTasks.Add(child)
subtask.trace("started") } else {
child.name = name
}
allTasksWg.Add(1)
t.children.Add(1)
if !nf {
go func() { go func() {
subtask.Wait() <-child.ctx.Done()
subtask.trace("finished: " + subtask.FinishCause().Error()) child.Finish(nil)
}() }()
} }
go func() { go func() {
subtask.Wait() <-child.finished
parent.subtasks.Remove(subtask) allTasksWg.Done()
parent.subTasksWg.Done() t.children.Done()
allTasks.Remove(child)
}() }()
return subtask
logger.Trace().Msg("task " + child.name + " started")
return child
} }
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish. // Name returns the name of the task without parent names.
// func (t *Task) Name() string {
// It must be called only after Finish is called. parts := strutils.SplitRune(t.name, '.')
func (t *Task) Wait() { return parts[len(parts)-1]
<-t.ctx.Done()
t.WaitSubTasks()
t.onFinishedWg.Wait()
} }
// WaitSubTasks waits for all subtasks of the task to finish. // String returns the full name of the task.
// func (t *Task) String() string {
// No more subtasks can be added after this call. return t.name
//
// It can be called before Finish is called.
func (t *Task) WaitSubTasks() {
t.subTasksWg.Wait()
} }
// tree returns a string representation of the task tree, with the given // MarshalText implements encoding.TextMarshaler.
// prefix prepended to each line. The prefix is used to indent the tree, func (t *Task) MarshalText() ([]byte, error) {
// and should be a string of spaces or a similar separator. return []byte(t.name), nil
//
// The resulting string is suitable for printing to the console, and can be
// used to debug the task tree.
//
// The tree is traversed in a depth-first manner, with each task's name and
// line number (if available) printed on a separate line. The line number is
// only printed if the task was created with a non-empty line argument.
//
// The returned string is not guaranteed to be stable, and may change between
// runs of the program. It is intended for debugging purposes only.
func (t *Task) tree(prefix ...string) string {
var sb strings.Builder
var pre string
if len(prefix) > 0 {
pre = prefix[0]
sb.WriteString(pre + "- ")
}
sb.WriteString(t.Name() + "\n")
t.subtasks.RangeAll(func(subtask *Task) {
sb.WriteString(subtask.tree(pre + " "))
})
return sb.String()
} }
// serialize returns a map[string]any representation of the task tree. func (t *Task) invokeWithRecover(fn func(), caller string) {
// defer func() {
// The map contains the following keys: if err := recover(); err != nil {
// - name: the name of the task logger.Error().
// - subtasks: a slice of maps, each representing a subtask Interface("err", err).
// Msg("panic in task " + t.name + "." + caller)
// The subtask maps contain the same keys, recursively. if common.IsDebug {
// panic(string(debug.Stack()))
// The returned map is suitable for encoding to JSON, and can be used }
// to debug the task tree. }
// }()
// The returned map is not guaranteed to be stable, and may change fn()
// between runs of the program. It is intended for debugging purposes
// only.
func (t *Task) serialize() map[string]any {
m := make(map[string]any)
parts := strings.Split(t.name, " > ")
m["name"] = parts[len(parts)-1]
if t.subtasks.Size() > 0 {
m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
t.subtasks.RangeAll(func(subtask *Task) {
m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
})
}
return m
} }

View file

@ -2,132 +2,112 @@ package task
import ( import (
"context" "context"
"sync/atomic" "sync"
"testing" "testing"
"time" "time"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
const ( func testTask() *Task {
rootTaskName = "root-task" return RootTask("test", false)
subTaskName = "subtask"
)
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
ExpectEqual(t, rootTaskName, rootTask.Name())
ExpectEqual(t, subTaskName, subTask.Name())
} }
func TestTaskCancellation(t *testing.T) { func TestChildTaskCancellation(t *testing.T) {
subTaskDone := make(chan struct{}) t.Cleanup(testCleanup)
rootTask := GlobalTask(rootTaskName) parent := testTask()
subTask := rootTask.Subtask(subTaskName) child := parent.Subtask("")
go func() { go func() {
subTask.Wait() defer child.Finish(nil)
close(subTaskDone) for {
select {
case <-child.Context().Done():
return
default:
continue
}
}
}() }()
go rootTask.Finish(nil) parent.cancel(nil) // should also cancel child
select { select {
case <-subTaskDone: case <-child.Finished():
err := subTask.Context().Err() ExpectError(t, context.Canceled, child.Context().Err())
ExpectError(t, context.Canceled, err) default:
cause := context.Cause(subTask.Context())
ExpectError(t, ErrTaskCanceled, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected") t.Fatal("subTask context was not canceled as expected")
} }
} }
func TestOnComplete(t *testing.T) { func TestTaskOnCancelOnFinished(t *testing.T) {
rootTask := GlobalTask(rootTaskName) t.Cleanup(testCleanup)
task := rootTask.Subtask(subTaskName) task := testTask()
var value atomic.Int32 var shouldTrueOnCancel bool
task.OnFinished("set value", func() { var shouldTrueOnFinish bool
value.Store(1234)
task.OnCancel("", func() {
shouldTrueOnCancel = true
}) })
task.OnFinished("", func() {
shouldTrueOnFinish = true
})
ExpectFalse(t, shouldTrueOnFinish)
task.Finish(nil) task.Finish(nil)
ExpectEqual(t, value.Load(), 1234) ExpectTrue(t, shouldTrueOnCancel)
ExpectTrue(t, shouldTrueOnFinish)
} }
func TestGlobalContextWait(t *testing.T) { func TestCommonFlowWithGracefulShutdown(t *testing.T) {
testResetGlobalTask() t.Cleanup(testCleanup)
defer CancelGlobalContext() task := testTask()
rootTask := GlobalTask(rootTaskName) finished := false
finished1, finished2 := false, false task.OnFinished("", func() {
finished = true
subTask1 := rootTask.Subtask(subTaskName)
subTask2 := rootTask.Subtask(subTaskName)
subTask1.OnFinished("", func() {
finished1 = true
})
subTask2.OnFinished("", func() {
finished2 = true
}) })
go func() { go func() {
time.Sleep(500 * time.Millisecond) defer task.Finish(nil)
subTask1.Finish(nil) for {
select {
case <-task.Context().Done():
return
default:
continue
}
}
}() }()
go func() { ExpectNoError(t, GracefulShutdown(1*time.Second))
time.Sleep(500 * time.Millisecond) ExpectTrue(t, finished)
subTask2.Finish(nil)
}()
go func() { <-root.finished
subTask1.Wait() ExpectError(t, context.Canceled, task.Context().Err())
subTask2.Wait() ExpectError(t, ErrProgramExiting, task.FinishCause())
rootTask.Finish(nil)
}()
_ = GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
} }
func TestTimeoutOnGlobalContextWait(t *testing.T) { func TestTimeoutOnGracefulShutdown(t *testing.T) {
testResetGlobalTask() t.Cleanup(testCleanup)
_ = testTask()
rootTask := GlobalTask(rootTaskName) ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
rootTask.Subtask(subTaskName)
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
} }
func TestGlobalContextCancellation(t *testing.T) { func TestFinishMultipleCalls(t *testing.T) {
testResetGlobalTask() t.Cleanup(testCleanup)
task := testTask()
taskDone := make(chan struct{}) var wg sync.WaitGroup
rootTask := GlobalTask(rootTaskName) wg.Add(5)
for range 5 {
go func() { go func() {
rootTask.Wait() defer wg.Done()
close(taskDone) task.Finish(nil)
}() }()
CancelGlobalContext()
select {
case <-taskDone:
err := rootTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(rootTask.Context())
ExpectError(t, ErrProgramExiting, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected")
} }
wg.Wait()
} }

96
internal/task/utils.go Normal file
View file

@ -0,0 +1,96 @@
package task
import (
"context"
"encoding/json"
"errors"
"slices"
"sync"
"time"
"github.com/yusing/go-proxy/internal/logging"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
var ErrProgramExiting = errors.New("program exiting")
var logger = logging.With().Str("module", "task").Logger()
var root = newRoot()
var allTasks = F.NewSet[*Task]()
var allTasksWg sync.WaitGroup
func testCleanup() {
root = newRoot()
allTasks.Clear()
allTasksWg = sync.WaitGroup{}
}
// RootTask returns a new Task with the given name, derived from the root context.
func RootTask(name string, needFinish bool) *Task {
return root.Subtask(name, needFinish)
}
func newRoot() *Task {
t := &Task{name: "root"}
t.ctx, t.cancel = context.WithCancelCause(context.Background())
return t
}
func RootContext() context.Context {
return root.ctx
}
func RootContextCanceled() <-chan struct{} {
return root.ctx.Done()
}
func OnProgramExit(about string, fn func()) {
root.OnFinished(about, fn)
}
// GracefulShutdown waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GracefulShutdown(timeout time.Duration) (err error) {
root.cancel(ErrProgramExiting)
done := make(chan struct{})
after := time.After(timeout)
go func() {
allTasksWg.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
b, err := json.Marshal(DebugTaskList())
if err != nil {
logger.Warn().Err(err).Msg("failed to marshal tasks")
return context.DeadlineExceeded
}
logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
return context.DeadlineExceeded
}
}
}
// DebugTaskList returns list of all tasks.
//
// The returned string is suitable for printing to the console.
func DebugTaskList() []string {
l := make([]string, 0, allTasks.Size())
allTasks.RangeAll(func(t *Task) {
l = append(l, t.name)
})
slices.Sort(l)
return l
}

View file

@ -1,30 +0,0 @@
package utils
import (
"encoding/json"
"sync/atomic"
)
type AtomicValue[T any] struct {
atomic.Value
}
func (a *AtomicValue[T]) Load() T {
return a.Value.Load().(T)
}
func (a *AtomicValue[T]) Store(v T) {
a.Value.Store(v)
}
func (a *AtomicValue[T]) Swap(v T) T {
return a.Value.Swap(v).(T)
}
func (a *AtomicValue[T]) CompareAndSwap(oldV, newV T) bool {
return a.Value.CompareAndSwap(oldV, newV)
}
func (a *AtomicValue[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.Load())
}

View file

@ -0,0 +1,30 @@
package atomic
import (
"encoding/json"
"sync/atomic"
)
type Value[T any] struct {
atomic.Value
}
func (a *Value[T]) Load() T {
return a.Value.Load().(T)
}
func (a *Value[T]) Store(v T) {
a.Value.Store(v)
}
func (a *Value[T]) Swap(v T) T {
return a.Value.Swap(v).(T)
}
func (a *Value[T]) CompareAndSwap(oldV, newV T) bool {
return a.Value.CompareAndSwap(oldV, newV)
}
func (a *Value[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.Load())
}

View file

@ -152,9 +152,10 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
return m.CollectErrors(do) return m.CollectErrors(do)
} }
errs := make([]error, 0) var errs []error
mu := sync.Mutex{} var mu sync.Mutex
wg := sync.WaitGroup{} var wg sync.WaitGroup
m.Range(func(k KT, v VT) bool { m.Range(func(k KT, v VT) bool {
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -171,24 +172,6 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
return errs return errs
} }
// RemoveAll removes all key-value pairs from the map where the value matches the given criteria.
//
// Parameters:
//
// criteria: function to determine whether a value should be removed
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool {
if criteria(v) {
m.Delete(k)
}
return true
})
}
func (m Map[KT, VT]) Has(k KT) bool { func (m Map[KT, VT]) Has(k KT) bool {
_, ok := m.Load(k) _, ok := m.Load(k)
return ok return ok

View file

@ -61,3 +61,7 @@ func (set Set[T]) RangeAllParallel(f func(T)) {
func (set Set[T]) Size() int { func (set Set[T]) Size() int {
return set.m.Size() return set.m.Size()
} }
func (set Set[T]) IsEmpty() bool {
return set.m == nil || set.m.Size() == 0
}

View file

@ -20,10 +20,17 @@ func IgnoreError[Result any](r Result, _ error) Result {
return r return r
} }
func fmtError(err error) string {
if err == nil {
return "<nil>"
}
return ansi.StripANSI(err.Error())
}
func ExpectNoError(t *testing.T, err error) { func ExpectNoError(t *testing.T, err error) {
t.Helper() t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() { if err != nil {
t.Errorf("expected err=nil, got %s", ansi.StripANSI(err.Error())) t.Errorf("expected err=nil, got %s", fmtError(err))
t.FailNow() t.FailNow()
} }
} }
@ -31,7 +38,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) { func ExpectError(t *testing.T, expected error, err error) {
t.Helper() t.Helper()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected, ansi.StripANSI(err.Error())) t.Errorf("expected err %s, got %s", expected, fmtError(err))
t.FailNow() t.FailNow()
} }
} }
@ -39,7 +46,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) { func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper() t.Helper()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected, ansi.StripANSI(err.Error())) t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err))
t.FailNow() t.FailNow()
} }
} }
@ -48,7 +55,7 @@ func ExpectErrorT[T error](t *testing.T, err error) {
t.Helper() t.Helper()
var errAs T var errAs T
if !errors.As(err, &errAs) { if !errors.As(err, &errAs) {
t.Errorf("expected err %T, got %s", errAs, ansi.StripANSI(err.Error())) t.Errorf("expected err %T, got %s", errAs, fmtError(err))
t.FailNow() t.FailNow()
} }
} }

View file

@ -16,8 +16,10 @@ var (
func NewConfigFileWatcher(filename string) Watcher { func NewConfigFileWatcher(filename string) Watcher {
configDirWatcherMu.Lock() configDirWatcherMu.Lock()
defer configDirWatcherMu.Unlock() defer configDirWatcherMu.Unlock()
if configDirWatcher == nil { if configDirWatcher == nil {
configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath) t := task.RootTask("config_dir_watcher", false)
configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath)
} }
return configDirWatcher.Add(filename) return configDirWatcher.Add(filename)
} }

View file

@ -37,7 +37,7 @@ type DirWatcher struct {
// //
// Note that the returned DirWatcher is not ready to use until the goroutine // Note that the returned DirWatcher is not ready to use until the goroutine
// started by NewDirectoryWatcher has finished. // started by NewDirectoryWatcher has finished.
func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher { func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
//! subdirectories are not watched //! subdirectories are not watched
w, err := fsnotify.NewWatcher() w, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
@ -56,7 +56,7 @@ func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
fwMap: F.NewMapOf[string, *fileWatcher](), fwMap: F.NewMapOf[string, *fileWatcher](),
eventCh: make(chan Event), eventCh: make(chan Event),
errCh: make(chan E.Error), errCh: make(chan E.Error),
task: callerSubtask, task: parent.Subtask("dir_watcher(" + dirPath + ")"),
} }
go helper.start() go helper.start()
return helper return helper
@ -80,17 +80,19 @@ func (h *DirWatcher) Add(relPath string) Watcher {
eventCh: make(chan Event), eventCh: make(chan Event),
errCh: make(chan E.Error), errCh: make(chan E.Error),
} }
h.task.OnFinished("close file watcher for "+relPath, func() {
close(s.eventCh)
close(s.errCh)
})
h.fwMap.Store(relPath, s) h.fwMap.Store(relPath, s)
return s return s
} }
func (h *DirWatcher) cleanup() {
h.w.Close()
close(h.eventCh)
close(h.errCh)
h.task.Finish(nil)
}
func (h *DirWatcher) start() { func (h *DirWatcher) start() {
defer close(h.eventCh) defer h.cleanup()
defer h.w.Close()
for { for {
select { select {

View file

@ -1,6 +1,7 @@
package events package events
import ( import (
"runtime/debug"
"time" "time"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
@ -17,7 +18,7 @@ type (
onFlush OnFlushFunc onFlush OnFlushFunc
onError OnErrorFunc onError OnErrorFunc
} }
OnFlushFunc = func(flushTask *task.Task, events []Event) OnFlushFunc = func(events []Event)
OnErrorFunc = func(err E.Error) OnErrorFunc = func(err E.Error)
) )
@ -38,9 +39,9 @@ const eventQueueCapacity = 10
// but the onFlush function can return earlier (e.g. run in another goroutine). // but the onFlush function can return earlier (e.g. run in another goroutine).
// //
// If task is canceled before the flushInterval is reached, the events in queue will be discarded. // If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { func NewEventQueue(queueTask *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
return &EventQueue{ return &EventQueue{
task: parent.Subtask("event queue"), task: queueTask,
queue: make([]Event, 0, eventQueueCapacity), queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval), ticker: time.NewTicker(flushInterval),
flushInterval: flushInterval, flushInterval: flushInterval,
@ -50,19 +51,20 @@ func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlu
} }
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
if common.IsProduction { origOnFlush := e.onFlush
origOnFlush := e.onFlush // recover panic in onFlush when in production mode
// recover panic in onFlush when in production mode e.onFlush = func(events []Event) {
e.onFlush = func(flushTask *task.Task, events []Event) { defer func() {
defer func() { if err := recover(); err != nil {
if err := recover(); err != nil { e.onError(E.New("recovered panic in onFlush").
e.onError(E.New("recovered panic in onFlush"). Withf("%v", err).
Withf("%v", err). Subject(e.task.Name()))
Subject(e.task.Parent().String())) if common.IsDebug {
panic(string(debug.Stack()))
} }
}() }
origOnFlush(flushTask, events) }()
} origOnFlush(events)
} }
go func() { go func() {
@ -75,19 +77,24 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
return return
case <-e.ticker.C: case <-e.ticker.C:
if len(e.queue) > 0 { if len(e.queue) > 0 {
flushTask := e.task.Subtask("flush events") // clone -> clear -> flush
queue := e.queue queue := make([]Event, len(e.queue))
e.queue = make([]Event, 0, eventQueueCapacity) copy(queue, e.queue)
go e.onFlush(flushTask, queue)
flushTask.Wait() e.queue = e.queue[:0]
e.onFlush(queue)
} }
e.ticker.Reset(e.flushInterval) e.ticker.Reset(e.flushInterval)
case event, ok := <-eventCh: case event, ok := <-eventCh:
e.queue = append(e.queue, event)
if !ok { if !ok {
return return
} }
case err := <-errCh: e.queue = append(e.queue, event)
case err, ok := <-errCh:
if !ok {
return
}
if err != nil { if err != nil {
e.onError(err) e.onError(err)
} }
@ -95,10 +102,3 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
} }
}() }()
} }
// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
e.task.Wait()
}

View file

@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -23,9 +23,9 @@ type (
monitor struct { monitor struct {
service string service string
config *health.HealthCheckConfig config *health.HealthCheckConfig
url U.AtomicValue[types.URL] url atomic.Value[types.URL]
status U.AtomicValue[health.Status] status atomic.Value[health.Status]
lastResult *health.HealthCheckResult lastResult *health.HealthCheckResult
lastSeen time.Time lastSeen time.Time
@ -59,10 +59,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask *task.Task) E.Error { func (mon *monitor) Start(parent task.Parent) E.Error {
mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask
if mon.config.Interval <= 0 { if mon.config.Interval <= 0 {
return E.From(ErrNegativeInterval) return E.From(ErrNegativeInterval)
} }
@ -71,6 +68,9 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service)) mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service))
} }
mon.service = parent.Name()
mon.task = parent.Subtask("health_monitor")
go func() { go func() {
logger := logging.With().Str("name", mon.service).Logger() logger := logging.With().Str("name", mon.service).Logger()
@ -78,10 +78,10 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
if mon.status.Load() != health.StatusError { if mon.status.Load() != health.StatusError {
mon.status.Store(health.StatusUnknown) mon.status.Store(health.StatusUnknown)
} }
mon.task.Finish(nil)
if mon.metric != nil { if mon.metric != nil {
mon.metric.Reset() mon.metric.Reset()
} }
mon.task.Finish(nil)
}() }()
if err := mon.checkUpdateHealth(); err != nil { if err := mon.checkUpdateHealth(); err != nil {
@ -108,6 +108,11 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
return nil return nil
} }
// Task implements task.TaskStarter.
func (mon *monitor) Task() *task.Task {
return mon.task
}
// Finish implements task.TaskFinisher. // Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason any) { func (mon *monitor) Finish(reason any) {
mon.task.Finish(reason) mon.task.Finish(reason)