diff --git a/cmd/main.go b/cmd/main.go index 413ca53..0dc0ecf 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,16 +1,12 @@ package main import ( - "context" "encoding/json" "io" "log" "net/http" "os" "os/signal" - "reflect" - "runtime" - "strings" "syscall" "time" @@ -20,13 +16,10 @@ import ( "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" - "github.com/yusing/go-proxy/internal/docker" - "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/server" - F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/pkg" ) @@ -39,7 +32,6 @@ func main() { } l := logrus.WithField("module", "main") - onShutdown := F.NewSlice[func()]() if common.IsDebug { logrus.SetLevel(logrus.DebugLevel) @@ -127,9 +119,6 @@ func main() { cfg.StartProxyProviders() cfg.WatchChanges() - onShutdown.Add(docker.CloseAllClients) - onShutdown.Add(cfg.Dispose) - sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM) @@ -137,9 +126,7 @@ func main() { autocert := cfg.GetAutoCertProvider() if autocert != nil { - ctx, cancel := context.WithCancel(context.Background()) - onShutdown.Add(cancel) - if err := autocert.Setup(ctx); err != nil { + if err := autocert.Setup(); err != nil { l.Fatal(err) } } else { @@ -164,55 +151,24 @@ func main() { proxyServer.Start() apiServer.Start() - onShutdown.Add(proxyServer.Stop) - onShutdown.Add(apiServer.Stop) - - go idlewatcher.Start() - onShutdown.Add(idlewatcher.Stop) // wait for signal <-sig // grafully shutdown logrus.Info("shutting down") - done := make(chan struct{}, 1) - currentIdx := 0 - - go func() { - onShutdown.ForEach(func(f func()) { - l.Debugf("waiting for %s to complete...", funcName(f)) - f() - currentIdx++ - l.Debugf("%s done", funcName(f)) - }) - close(done) - }() - - timeout := time.After(time.Duration(cfg.Value().TimeoutShutdown) * time.Second) - select { - case <-done: - logrus.Info("shutdown complete") - case <-timeout: - logrus.Info("timeout waiting for shutdown") - for i := currentIdx; i < onShutdown.Size(); i++ { - l.Warnf("%s() is still running", funcName(onShutdown.Get(i))) - } - } + common.CancelGlobalContext() + common.GlobalContextWait(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) } func prepareDirectory(dir string) { if _, err := os.Stat(dir); os.IsNotExist(err) { - if err = os.MkdirAll(dir, 0755); err != nil { + if err = os.MkdirAll(dir, 0o755); err != nil { logrus.Fatalf("failed to create directory %s: %v", dir, err) } } } -func funcName(f func()) string { - parts := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "/go-proxy/") - return parts[len(parts)-1] -} - func printJSON(obj any) { j, err := E.Check(json.MarshalIndent(obj, "", " ")) if err != nil { diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index ca7380d..011c89f 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -1,7 +1,6 @@ package autocert import ( - "context" "crypto/tls" "crypto/x509" "os" @@ -14,6 +13,7 @@ import ( "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/types" U "github.com/yusing/go-proxy/internal/utils" @@ -136,20 +136,20 @@ func (p *Provider) ShouldRenewOn() time.Time { panic("no certificate available") } -func (p *Provider) ScheduleRenewal(ctx context.Context) { +func (p *Provider) ScheduleRenewal() { if p.GetName() == ProviderLocal { return } - logger.Debug("started renewal scheduler") - defer logger.Debug("renewal scheduler stopped") - ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() + task := common.NewTask("cert renew scheduler") + defer task.Finished() + for { select { - case <-ctx.Done(): + case <-task.Context().Done(): return case <-ticker.C: // check every 5 seconds if err := p.renewIfNeeded(); err.HasError() { diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 2b44f94..ef754c1 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -1,13 +1,12 @@ package autocert import ( - "context" "os" E "github.com/yusing/go-proxy/internal/error" ) -func (p *Provider) Setup(ctx context.Context) (err E.NestedError) { +func (p *Provider) Setup() (err E.NestedError) { if err = p.LoadCert(); err != nil { if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist return err @@ -18,7 +17,7 @@ func (p *Provider) Setup(ctx context.Context) (err E.NestedError) { } } - go p.ScheduleRenewal(ctx) + go p.ScheduleRenewal() for _, expiry := range p.GetExpiries() { logger.Infof("certificate expire on %s", expiry) diff --git a/internal/common/task.go b/internal/common/task.go new file mode 100644 index 0000000..1f4120d --- /dev/null +++ b/internal/common/task.go @@ -0,0 +1,158 @@ +package common + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/puzpuzpuz/xsync/v3" + "github.com/sirupsen/logrus" +) + +var ( + globalCtx, globalCtxCancel = context.WithCancel(context.Background()) + globalCtxWg sync.WaitGroup + globalCtxTraceMap = xsync.NewMapOf[*task, struct{}]() +) + +type ( + Task interface { + Name() string + Context() context.Context + Subtask(usageFmt string, args ...interface{}) Task + SubtaskWithCancel(usageFmt string, args ...interface{}) (Task, context.CancelFunc) + Finished() + } + task struct { + ctx context.Context + subtasks []*task + name string + finished bool + mu sync.Mutex + } +) + +func (t *task) Name() string { + return t.name +} + +func (t *task) Context() context.Context { + return t.ctx +} + +func (t *task) Finished() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.finished { + return + } + t.finished = true + if _, ok := globalCtxTraceMap.Load(t); ok { + globalCtxWg.Done() + globalCtxTraceMap.Delete(t) + } +} + +func (t *task) Subtask(format string, args ...interface{}) Task { + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + t.mu.Lock() + defer t.mu.Unlock() + sub := newSubTask(t.ctx, format) + t.subtasks = append(t.subtasks, sub) + return sub +} + +func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + t.mu.Lock() + defer t.mu.Unlock() + ctx, cancel := context.WithCancel(t.ctx) + sub := newSubTask(ctx, format) + t.subtasks = append(t.subtasks, sub) + return sub, cancel +} + +func (t *task) Tree(prefix ...string) string { + var sb strings.Builder + var pre string + if len(prefix) > 0 { + pre = prefix[0] + } + sb.WriteString(pre) + sb.WriteString(t.Name() + "\n") + for _, sub := range t.subtasks { + if sub.finished { + continue + } + sb.WriteString(sub.Tree(pre + " ")) + } + return sb.String() +} + +func newSubTask(ctx context.Context, name string) *task { + t := &task{ + ctx: ctx, + name: name, + } + globalCtxTraceMap.Store(t, struct{}{}) + globalCtxWg.Add(1) + return t +} + +func NewTask(format string, args ...interface{}) Task { + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + return newSubTask(globalCtx, format) +} + +func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { + subCtx, cancel := context.WithCancel(globalCtx) + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + return newSubTask(subCtx, format), cancel +} + +func GlobalTask(format string, args ...interface{}) Task { + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + return &task{ + ctx: globalCtx, + name: format, + } +} + +func CancelGlobalContext() { + globalCtxCancel() +} + +func GlobalContextWait(timeout time.Duration) { + done := make(chan struct{}) + after := time.After(timeout) + go func() { + globalCtxWg.Wait() + close(done) + }() + for { + select { + case <-done: + return + case <-after: + logrus.Println("Timeout waiting for these tasks to finish:") + globalCtxTraceMap.Range(func(t *task, _ struct{}) bool { + logrus.Println(t.Tree()) + return true + }) + return + } + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 242fe25..444e0f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,6 @@ package config import ( - "context" "os" "github.com/sirupsen/logrus" @@ -25,10 +24,9 @@ type Config struct { l logrus.FieldLogger - watcher W.Watcher - watcherCtx context.Context - watcherCancel context.CancelFunc - reloadReq chan struct{} + watcher W.Watcher + + reloadReq chan struct{} } var instance *Config @@ -76,14 +74,6 @@ func (cfg *Config) GetAutoCertProvider() *autocert.Provider { return cfg.autocertProvider } -func (cfg *Config) Dispose() { - if cfg.watcherCancel != nil { - cfg.watcherCancel() - cfg.l.Debug("stopped watcher") - } - cfg.stopProviders() -} - func (cfg *Config) Reload() (err E.NestedError) { cfg.stopProviders() err = cfg.load() @@ -96,11 +86,13 @@ func (cfg *Config) StartProxyProviders() { } func (cfg *Config) WatchChanges() { - cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background()) + task := common.NewTask("Config watcher") + defer task.Finished() + go func() { for { select { - case <-cfg.watcherCtx.Done(): + case <-task.Context().Done(): return case <-cfg.reloadReq: if err := cfg.Reload(); err != nil { @@ -110,10 +102,10 @@ func (cfg *Config) WatchChanges() { } }() go func() { - eventCh, errCh := cfg.watcher.Events(cfg.watcherCtx) + eventCh, errCh := cfg.watcher.Events(task.Context()) for { select { - case <-cfg.watcherCtx.Done(): + case <-task.Context().Done(): return case event := <-eventCh: if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed { diff --git a/internal/config/query.go b/internal/config/query.go index 2f275c2..9d04469 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -97,23 +97,31 @@ func (cfg *Config) HomepageConfig() homepage.Config { return hpCfg } -func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { +func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]U.SerializedObject { routes := make(map[string]U.SerializedObject) - cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { - if !r.Started() { - return + if len(typeFilter) == 0 { + typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream} + } + for _, t := range typeFilter { + switch t { + case R.RouteTypeReverseProxy: + R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { + obj, err := U.Serialize(r) + if err != nil { + panic(err) // should not happen + } + routes[alias] = obj + }) + case R.RouteTypeStream: + R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) { + obj, err := U.Serialize(r) + if err != nil { + panic(err) // should not happen + } + routes[alias] = obj + }) } - obj, err := U.Serialize(r) - if err != nil { - cfg.l.Error(err) - return - } - obj["provider"] = p.GetName() - obj["type"] = string(r.Type) - obj["started"] = r.Started() - obj["raw"] = r.Entry - routes[alias] = obj - }) + } return routes } diff --git a/internal/route/constants.go b/internal/route/constants.go deleted file mode 100644 index 06650ef..0000000 --- a/internal/route/constants.go +++ /dev/null @@ -1,9 +0,0 @@ -package route - -import ( - "time" -) - -const ( - streamStopListenTimeout = 1 * time.Second -) diff --git a/internal/route/http.go b/internal/route/http.go index e36a4d5..8654135 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -1,7 +1,6 @@ package route import ( - "context" "errors" "fmt" "net/http" @@ -10,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/api/v1/errorpage" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" @@ -52,6 +52,10 @@ func (rp ReverseProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) rp.ReverseProxy.ServeHTTP(w, r) } +func GetReverseProxies() F.Map[string, *HTTPRoute] { + return httpRoutes +} + func SetFindMuxDomains(domains []string) { if len(domains) == 0 { findMuxFunc = findMuxAnyDomain @@ -91,8 +95,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { } if !entry.HealthCheck.Disabled { r.healthMon = health.NewHTTPHealthMonitor( - context.Background(), - string(entry.Alias), + common.GlobalTask("Reverse proxy "+r.String()), entry.URL, entry.HealthCheck, ) diff --git a/internal/route/stream.go b/internal/route/stream.go index 36c3976..aef8c0f 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -5,14 +5,14 @@ import ( "errors" "fmt" "sync" - "sync/atomic" - "time" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" + F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -20,16 +20,18 @@ type StreamRoute struct { *P.StreamEntry StreamImpl `json:"-"` - url url.URL - healthMon health.HealthMonitor + HealthMon health.HealthMonitor `json:"health"` + + url url.URL wg sync.WaitGroup - ctx context.Context + task common.Task cancel context.CancelFunc - connCh chan any - started atomic.Bool - l logrus.FieldLogger + connCh chan any + l logrus.FieldLogger + + mu sync.Mutex } type StreamImpl interface { @@ -40,6 +42,12 @@ type StreamImpl interface { String() string } +var streamRoutes = F.NewMapOf[string, *StreamRoute]() + +func GetStreamProxies() F.Map[string, *StreamRoute] { + return streamRoutes +} + func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { // TODO: support non-coherent scheme if !entry.Scheme.IsCoherent() { @@ -60,9 +68,6 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { } else { base.StreamImpl = NewUDPRoute(base) } - if !entry.Healthcheck.Disabled { - base.healthMon = health.NewRawHealthMonitor(base.ctx, string(entry.Alias), url, entry.Healthcheck) - } base.l = logrus.WithField("route", base.StreamImpl) return base, nil } @@ -76,72 +81,71 @@ func (r *StreamRoute) URL() url.URL { } func (r *StreamRoute) Start() E.NestedError { - if r.Port.ProxyPort == PT.NoPort || r.started.Load() { + r.mu.Lock() + defer r.mu.Unlock() + + if r.Port.ProxyPort == PT.NoPort || r.task != nil { return nil } - r.ctx, r.cancel = context.WithCancel(context.Background()) + r.task, r.cancel = common.NewTaskWithCancel(r.String()) r.wg.Wait() if err := r.Setup(); err != nil { return E.FailWith("setup", err) } r.l.Infof("listening on port %d", r.Port.ListeningPort) - r.started.Store(true) r.wg.Add(2) - go r.grAcceptConnections() - go r.grHandleConnections() - if r.healthMon != nil { - r.healthMon.Start() + go r.acceptConnections() + go r.handleConnections() + if !r.Healthcheck.Disabled { + r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck) + r.HealthMon.Start() } + streamRoutes.Store(string(r.Alias), r) return nil } func (r *StreamRoute) Stop() E.NestedError { - if !r.started.Load() { + r.mu.Lock() + defer r.mu.Unlock() + + if r.task == nil { return nil } - r.started.Store(false) - if r.healthMon != nil { - r.healthMon.Stop() + streamRoutes.Delete(string(r.Alias)) + + if r.HealthMon != nil { + r.HealthMon.Stop() + r.HealthMon = nil } r.cancel() r.CloseListeners() - done := make(chan struct{}, 1) - go func() { - r.wg.Wait() - close(done) - }() + r.wg.Wait() + r.task.Finished() - timeout := time.After(streamStopListenTimeout) - for { - select { - case <-done: - r.l.Debug("stopped listening") - return nil - case <-timeout: - return E.FailedWhy("stop", "timed out") - } - } + r.task, r.cancel = nil, nil + + return nil } func (r *StreamRoute) Started() bool { - return r.started.Load() + return r.task != nil } -func (r *StreamRoute) grAcceptConnections() { +func (r *StreamRoute) acceptConnections() { defer r.wg.Done() for { select { - case <-r.ctx.Done(): + case <-r.task.Context().Done(): return default: conn, err := r.Accept() if err != nil { select { - case <-r.ctx.Done(): + case <-r.task.Context().Done(): return default: r.l.Error(err) @@ -153,12 +157,12 @@ func (r *StreamRoute) grAcceptConnections() { } } -func (r *StreamRoute) grHandleConnections() { +func (r *StreamRoute) handleConnections() { defer r.wg.Done() for { select { - case <-r.ctx.Done(): + case <-r.task.Context().Done(): return case conn := <-r.connCh: go func() { diff --git a/internal/route/tcp.go b/internal/route/tcp.go index d5e6621..b7367e1 100755 --- a/internal/route/tcp.go +++ b/internal/route/tcp.go @@ -51,7 +51,7 @@ func (route *TCPRoute) Handle(c any) error { defer clientConn.Close() - ctx, cancel := context.WithTimeout(route.ctx, tcpDialTimeout) + ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) defer cancel() serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) @@ -64,7 +64,7 @@ func (route *TCPRoute) Handle(c any) error { route.mu.Lock() - pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn) + pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) route.pipe = append(route.pipe, pipe) route.mu.Unlock() diff --git a/internal/route/udp.go b/internal/route/udp.go index cec210f..c15630c 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -93,7 +93,7 @@ func (route *UDPRoute) Accept() (any, error) { key, srcConn, dstConn, - U.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), + U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), } route.connMap.Store(key, conn) } diff --git a/internal/watcher/health/http.go b/internal/watcher/health/http.go index f0ca218..c4cd267 100644 --- a/internal/watcher/health/http.go +++ b/internal/watcher/health/http.go @@ -1,11 +1,11 @@ package health import ( - "context" "crypto/tls" "errors" "net/http" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/types" ) @@ -15,9 +15,9 @@ type HTTPHealthMonitor struct { pinger *http.Client } -func NewHTTPHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor { +func NewHTTPHealthMonitor(task common.Task, url types.URL, config HealthCheckConfig) HealthMonitor { mon := new(HTTPHealthMonitor) - mon.monitor = newMonitor(ctx, name, url, &config, mon.checkHealth) + mon.monitor = newMonitor(task, url, &config, mon.checkHealth) mon.pinger = &http.Client{Timeout: config.Timeout} if config.UseGet { mon.method = http.MethodGet @@ -29,7 +29,7 @@ func NewHTTPHealthMonitor(ctx context.Context, name string, url types.URL, confi func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) { req, reqErr := http.NewRequestWithContext( - mon.ctx, + mon.task.Context(), mon.method, mon.URL.String(), nil, diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index 32a4c6d..56b093d 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/types" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -27,7 +28,7 @@ type ( healthy atomic.Bool checkHealth HealthCheckFunc - ctx context.Context + task common.Task cancel context.CancelFunc done chan struct{} @@ -37,22 +38,18 @@ type ( var monMap = F.NewMapOf[string, HealthMonitor]() -func newMonitor(parentCtx context.Context, name string, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { - if parentCtx == nil { - parentCtx = context.Background() - } - ctx, cancel := context.WithCancel(parentCtx) +func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { + task, cancel := task.SubtaskWithCancel("Health monitor for %s", task.Name()) mon := &monitor{ - Name: name, + Name: task.Name(), URL: url.JoinPath(config.Path), Interval: config.Interval, checkHealth: healthCheckFunc, - ctx: ctx, + task: task, cancel: cancel, done: make(chan struct{}), } mon.healthy.Store(true) - monMap.Store(name, mon) return mon } @@ -65,8 +62,12 @@ func IsHealthy(name string) (healthy bool, ok bool) { } func (mon *monitor) Start() { + defer monMap.Store(mon.Name, mon) + defer logger.Debugf("%s health monitor started", mon) + go func() { defer close(mon.done) + defer mon.task.Finished() ok := mon.checkUpdateHealth() if !ok { @@ -78,7 +79,7 @@ func (mon *monitor) Start() { for { select { - case <-mon.ctx.Done(): + case <-mon.task.Context().Done(): return case <-ticker.C: ok = mon.checkUpdateHealth() @@ -92,7 +93,7 @@ func (mon *monitor) Start() { } func (mon *monitor) Stop() { - defer logger.Debugf("health monitor %q stopped", mon) + defer logger.Debugf("%s health monitor stopped", mon) monMap.Delete(mon.Name) diff --git a/internal/watcher/health/raw.go b/internal/watcher/health/raw.go index 4990d87..e3fb447 100644 --- a/internal/watcher/health/raw.go +++ b/internal/watcher/health/raw.go @@ -1,9 +1,9 @@ package health import ( - "context" "net" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/types" ) @@ -14,9 +14,9 @@ type ( } ) -func NewRawHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor { +func NewRawHealthMonitor(task common.Task, url types.URL, config HealthCheckConfig) HealthMonitor { mon := new(RawHealthMonitor) - mon.monitor = newMonitor(ctx, name, url, &config, mon.checkAvail) + mon.monitor = newMonitor(task, url, &config, mon.checkAvail) mon.dialer = &net.Dialer{ Timeout: config.Timeout, FallbackDelay: -1, @@ -25,7 +25,7 @@ func NewRawHealthMonitor(ctx context.Context, name string, url types.URL, config } func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) { - conn, dialErr := mon.dialer.DialContext(mon.ctx, mon.URL.Scheme, mon.URL.Host) + conn, dialErr := mon.dialer.DialContext(mon.task.Context(), mon.URL.Scheme, mon.URL.Host) if dialErr != nil { detail = dialErr.Error() /* trunk-ignore(golangci-lint/nilerr) */