diff --git a/internal/watcher/health/config.go b/internal/watcher/health/config.go index 88169a8..b12c28f 100644 --- a/internal/watcher/health/config.go +++ b/internal/watcher/health/config.go @@ -1,6 +1,7 @@ package health import ( + "context" "time" "github.com/yusing/go-proxy/internal/common" @@ -12,6 +13,8 @@ type HealthCheckConfig struct { UseGet bool `json:"use_get,omitempty"` Interval time.Duration `json:"interval" validate:"omitempty,min=1s"` Timeout time.Duration `json:"timeout" validate:"omitempty,min=1s"` + + BaseContext func() context.Context `json:"-"` } func DefaultHealthConfig() *HealthCheckConfig { diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go index c7567b2..14ad054 100644 --- a/internal/watcher/health/monitor/monitor.go +++ b/internal/watcher/health/monitor/monitor.go @@ -87,10 +87,15 @@ func newMonitor(u *url.URL, config *health.HealthCheckConfig, healthCheckFunc He } func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) { - if mon.task != nil { - return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause)) + switch { + case mon.config.BaseContext != nil: + ctx = mon.config.BaseContext() + case mon.task != nil: + ctx = mon.task.Context() + default: + ctx = context.Background() } - return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause)) + return context.WithTimeoutCause(ctx, mon.config.Timeout, errors.New(cause)) } // Start implements task.TaskStarter.