simplify task package implementation

This commit is contained in:
yusing 2025-01-01 06:07:32 +08:00
parent e7aaa95ec5
commit 1ab34ed46f
35 changed files with 547 additions and 600 deletions

View file

@ -23,16 +23,16 @@ lint:
enabled:
- hadolint@2.12.1-beta
- actionlint@1.7.4
- checkov@3.2.334
- checkov@3.2.344
- git-diff-check
- gofmt@1.20.4
- golangci-lint@1.62.2
- osv-scanner@1.9.1
- osv-scanner@1.9.2
- oxipng@9.1.3
- prettier@3.4.2
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.86.1
- trufflehog@3.88.0
actions:
disabled:
- trunk-announce

View file

@ -28,10 +28,10 @@ get:
go get -u ./cmd && go mod tidy
debug:
GODOXY_DEBUG=1 make run
GODOXY_DEBUG=1 BUILD_FLAGS="" make run
debug-trace:
GODOXY_DEBUG=1 GODOXY_TRACE=1 run
GODOXY_TRACE=1 make debug
profile:
GODEBUG=gctrace=1 make debug

View file

@ -159,8 +159,7 @@ func main() {
// grafully shutdown
logging.Info().Msg("shutting down")
task.CancelGlobalContext()
_ = task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
_ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown))
}
func prepareDirectory(dir string) {

View file

@ -52,7 +52,7 @@ func List(w http.ResponseWriter, r *http.Request) {
case ListHomepageConfig:
U.RespondJSON(w, r, config.HomepageConfig())
case ListTasks:
U.RespondJSON(w, r, task.DebugTaskMap())
U.RespondJSON(w, r, task.DebugTaskList())
default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
}

View file

@ -153,8 +153,8 @@ func (p *Provider) ScheduleRenewal() {
return
}
go func() {
task := task.GlobalTask("cert renew scheduler")
defer task.Finish("cert renew scheduler stopped")
task := task.RootTask("cert-renew-scheduler", true)
defer task.Finish(nil)
for {
renewalTime := p.ShouldRenewOn()

View file

@ -53,7 +53,7 @@ func newConfig() *Config {
return &Config{
value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](),
task: task.GlobalTask("config"),
task: task.RootTask("config", false),
}
}
@ -76,21 +76,19 @@ func MatchDomains() []string {
}
func WatchChanges() {
task := task.GlobalTask("config watcher")
t := task.RootTask("config_watcher", true)
eventQueue := events.NewEventQueue(
task,
t,
configEventFlushInterval,
OnConfigChange,
func(err E.Error) {
E.LogError("config reload error", err, &logger)
},
)
eventQueue.Start(cfgWatcher.Events(task.Context()))
eventQueue.Start(cfgWatcher.Events(t.Context()))
}
func OnConfigChange(flushTask *task.Task, ev []events.Event) {
defer flushTask.Finish("config reload complete")
func OnConfigChange(ev []events.Event) {
// no matter how many events during the interval
// just reload once and check the last event
switch ev[len(ev)-1].Action {
@ -116,14 +114,14 @@ func Reload() E.Error {
newCfg := newConfig()
err := newCfg.load()
if err != nil {
newCfg.task.Finish(err)
return err
}
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
instance.task.Finish("config changed")
instance.task.Wait()
*instance = *newCfg
instance = newCfg
instance.StartProxyProviders()
return nil
}
@ -143,8 +141,7 @@ func (cfg *Config) Task() *task.Task {
func (cfg *Config) StartProxyProviders() {
errs := cfg.providers.CollectErrorsParallel(
func(_ string, p *proxy.Provider) error {
subtask := cfg.task.Subtask(p.String())
return p.Start(subtask)
return p.Start(cfg.task)
})
if err := E.Join(errs...); err != nil {
@ -209,9 +206,6 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
}
func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
subtask := cfg.task.Subtask("load route providers")
defer subtask.Finish("done")
errs := E.NewBuilder("route provider errors")
results := E.NewBuilder("loaded route providers")

View file

@ -38,7 +38,7 @@ var (
)
func init() {
task.GlobalTask("close docker clients").OnFinished("", func() {
task.OnProgramExit("docker_clients_cleanup", func() {
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()

View file

@ -38,7 +38,7 @@ const (
// TODO: support stream
func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.RawEntry().HealthCheck
hcCfg.Timeout = idleWakerCheckTimeout
@ -46,8 +46,8 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
rp: rp,
stream: stream,
}
watcher, err := registerWatcher(providerSubTask, entry, waker)
task := parent.Subtask("idlewatcher")
watcher, err := registerWatcher(task, entry, waker)
if err != nil {
return nil, E.Errorf("register watcher: %w", err)
}
@ -63,7 +63,7 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
if common.PrometheusEnabled {
m := metrics.GetServiceMetrics()
fqn := providerSubTask.Parent().Name() + "/" + entry.TargetName()
fqn := parent.Name() + "/" + entry.TargetName()
waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn))
waker.metric.Set(float64(watcher.Status()))
}
@ -71,19 +71,18 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
}
// lifetime should follow route provider.
func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil)
func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(parent, entry, rp, nil)
}
func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(providerSubTask, entry, nil, stream)
func NewStreamWaker(parent task.Parent, entry route.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(parent, entry, nil, stream)
}
// Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
routeSubTask.Finish("ignored")
w.task.OnCancel("stop route and cleanup", func() {
routeSubTask.Parent().Finish(w.task.FinishCause())
func (w *Watcher) Start(parent task.Parent) E.Error {
w.task.OnCancel("route_cleanup", func() {
parent.Finish(w.task.FinishCause())
if w.metric != nil {
w.metric.Reset()
}
@ -91,6 +90,11 @@ func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
return nil
}
// Task implements health.HealthMonitor.
func (w *Watcher) Task() *task.Task {
return w.task
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) {
if w.stream != nil {

View file

@ -51,7 +51,7 @@ var (
const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
func registerWatcher(watcherTask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 {
@ -67,7 +67,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
w.Config = cfg
w.waker = waker
w.resetIdleTimer()
providerSubtask.Finish("used existing watcher")
watcherTask.Finish("used existing watcher")
return w, nil
}
@ -81,7 +81,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
Config: cfg,
waker: waker,
client: client,
task: providerSubtask,
task: watcherTask,
ticker: time.NewTicker(cfg.IdleTimeout),
}
w.stopByMethod = w.getStopCallback()
@ -210,8 +210,7 @@ func (w *Watcher) resetIdleTimer() {
}
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("docker event watcher")
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
eventCh, errCh = dockerWatcher.EventsWithOptions(w.Task().Context(), watcher.DockerListOptions{
Filters: watcher.NewDockerFilter(
watcher.DockerFilterContainer,
watcher.DockerFilterContainerNameID(w.ContainerID),

View file

@ -54,7 +54,7 @@ func SetMiddlewares(mws []map[string]any) error {
return nil
}
func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) {
func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
epAccessLoggerMu.Lock()
defer epAccessLoggerMu.Unlock()

View file

@ -50,10 +50,12 @@ func Join(errors ...error) Error {
if n == 0 {
return nil
}
errs := make([]error, 0, n)
errs := make([]error, n)
i := 0
for _, err := range errors {
if err != nil {
errs = append(errs, err)
errs[i] = err
i++
}
}
return &nestedError{Extras: errs}

View file

@ -22,6 +22,7 @@ type (
Path string `json:"path" validate:"required"`
Filters Filters `json:"filters"`
Fields Fields `json:"fields"`
// Retention *Retention
}
)
@ -31,7 +32,7 @@ var (
FormatJSON Format = "json"
)
const DefaultBufferSize = 100
const DefaultBufferSize = 64 * 1024 // 64KB
func DefaultConfig() *Config {
return &Config{

View file

@ -0,0 +1,22 @@
package accesslog
import (
"fmt"
"os"
"sync"
"github.com/yusing/go-proxy/internal/task"
)
type File struct {
*os.File
sync.Mutex
}
func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
return NewAccessLogger(parent, &File{File: f}, cfg), nil
}

View file

@ -7,18 +7,17 @@ import (
"net/http"
"net/url"
"strconv"
"time"
)
type (
CommonFormatter struct {
cfg *Fields
}
CombinedFormatter struct {
CommonFormatter
}
JSONFormatter struct {
CommonFormatter
cfg *Fields
GetTimeNow func() time.Time // for testing purposes only
}
CombinedFormatter CommonFormatter
JSONFormatter CommonFormatter
JSONLogEntry struct {
Time string `json:"time"`
IP string `json:"ip"`
@ -39,6 +38,8 @@ type (
}
)
const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
func scheme(req *http.Request) string {
if req.TLS != nil {
return "https"
@ -62,7 +63,7 @@ func clientIP(req *http.Request) string {
return req.RemoteAddr
}
func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
line.WriteString(req.Host)
@ -71,7 +72,7 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
line.WriteString(clientIP(req))
line.WriteString(" - - [")
line.WriteString(timeNow())
line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
line.WriteString("] \"")
line.WriteString(req.Method)
@ -86,8 +87,8 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
line.WriteString(strconv.FormatInt(res.ContentLength, 10))
}
func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
f.CommonFormatter.Format(line, req, res)
func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
(*CommonFormatter)(f).Format(line, req, res)
line.WriteString(" \"")
line.WriteString(req.Referer())
line.WriteString("\" \"")
@ -95,14 +96,14 @@ func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *ht
line.WriteRune('"')
}
func (f JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
query := f.cfg.Query.ProcessQuery(req.URL.Query())
headers := f.cfg.Headers.ProcessHeaders(req.Header)
headers.Del("Cookie")
cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
entry := JSONLogEntry{
Time: timeNow(),
Time: f.GetTimeNow().Format(LogTimeFormat),
IP: clientIP(req),
Method: req.Method,
Scheme: scheme(req),

View file

@ -9,6 +9,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
@ -52,10 +53,13 @@ func New(cfg *Config) *LoadBalancer {
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
lb.startTime = time.Now()
lb.task = routeSubtask
lb.task.OnFinished("loadbalancer cleanup", func() {
lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
parent.OnCancel("lb_remove_route", func() {
routes.DeleteHTTPRoute(lb.Link)
})
lb.task.OnFinished("cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v)
@ -66,6 +70,11 @@ func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
return nil
}
// Task implements task.TaskStarter.
func (lb *LoadBalancer) Task() *task.Task {
return lb.task
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason)

View file

@ -32,10 +32,10 @@ func setup() {
return
}
task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
t := task.RootTask("error_page", true)
dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
loadContent()
go watchDir(task)
go watchDir()
}
func GetStaticFile(filename string) ([]byte, bool) {
@ -73,11 +73,11 @@ func loadContent() {
}
}
func watchDir(task *task.Task) {
eventCh, errCh := dirWatcher.Events(task.Context())
func watchDir() {
eventCh, errCh := dirWatcher.Events(task.RootContext())
for {
select {
case <-task.Context().Done():
case <-task.RootContextCanceled():
return
case event, ok := <-eventCh:
if !ok {

View file

@ -24,8 +24,6 @@ type Server struct {
httpsStarted bool
startTime time.Time
task *task.Task
l zerolog.Logger
}
@ -76,7 +74,6 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
task: task.GlobalTask(opt.Name + " server"),
l: logger,
}
}
@ -108,7 +105,7 @@ func (s *Server) Start() {
s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
}
s.task.OnFinished("stop server", s.stop)
task.OnProgramExit("server."+s.Name+".stop", s.stop)
}
func (s *Server) stop() {
@ -117,12 +114,12 @@ func (s *Server) stop() {
}
if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(s.task.Context()))
s.handleErr("http", s.http.Shutdown(task.RootContext()))
s.httpStarted = false
}
if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(s.task.Context()))
s.handleErr("https", s.https.Shutdown(task.RootContext()))
s.httpsStarted = false
}
}

View file

@ -35,7 +35,7 @@ var (
const dispatchErr = "notification dispatch error"
func StartNotifDispatcher(parent *task.Task) *Dispatcher {
func StartNotifDispatcher(parent task.Parent) *Dispatcher {
dispatcher = &Dispatcher{
task: parent.Subtask("notification"),
logCh: make(chan *LogMessage),

View file

@ -73,19 +73,17 @@ func (r *HTTPRoute) String() string {
return r.TargetName()
}
// Start implements*task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
// Start implements task.TaskStarter.
func (r *HTTPRoute) Start(parent task.Parent) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
}
r.task = providerSubtask
r.task = parent.Subtask("http."+r.TargetName(), false)
switch {
case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
waker, err := idlewatcher.NewHTTPWaker(r.task, r.ReverseProxyEntry, r.rp)
if err != nil {
r.task.Finish(err)
return err
@ -98,7 +96,7 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
if err == nil {
fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck)
r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
r.task.OnCancel("close docker client", client.Close)
r.task.OnCancel("close_docker_client", client.Close)
}
}
if r.HealthMon == nil {
@ -137,29 +135,32 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
}
if r.HealthMon != nil {
healthMonTask := r.task.Subtask("health monitor")
if err := r.HealthMon.Start(healthMonTask); err != nil {
if err := r.HealthMon.Start(r.task); err != nil {
E.LogWarn("health monitor error", err, &r.l)
healthMonTask.Finish(err)
}
}
if entry.UseLoadBalance(r) {
r.addToLoadBalancer()
r.addToLoadBalancer(parent)
} else {
routes.SetHTTPRoute(r.TargetName(), r)
r.task.OnFinished("remove from route table", func() {
r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteHTTPRoute(r.TargetName())
})
}
if common.PrometheusEnabled {
r.task.OnFinished("unreg metrics", r.rp.UnregisterMetrics)
r.task.OnFinished("metrics_cleanup", r.rp.UnregisterMetrics)
}
return nil
}
// Finish implements*task.TaskFinisher.
// Task implements task.TaskStarter.
func (r *HTTPRoute) Task() *task.Task {
return r.task
}
// Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason any) {
r.task.Finish(reason)
}
@ -168,7 +169,7 @@ func (r *HTTPRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(w, req)
}
func (r *HTTPRoute) addToLoadBalancer() {
func (r *HTTPRoute) addToLoadBalancer(parent task.Parent) {
var lb *loadbalancer.LoadBalancer
cfg := r.Raw.LoadBalance
l, ok := routes.GetHTTPRoute(cfg.Link)
@ -182,11 +183,7 @@ func (r *HTTPRoute) addToLoadBalancer() {
}
} else {
lb = loadbalancer.New(cfg)
lbTask := r.task.Parent().Subtask("loadbalancer " + cfg.Link)
lbTask.OnCancel("remove lb from routes", func() {
routes.DeleteHTTPRoute(cfg.Link)
})
if err := lb.Start(lbTask); err != nil {
if err := lb.Start(parent); err != nil {
panic(err) // should always return nil
}
linked = &HTTPRoute{
@ -203,9 +200,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
routes.SetHTTPRoute(cfg.Link, linked)
}
r.loadBalancer = lb
r.server = loadbalance.NewServer(r.task.String(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
r.server = loadbalance.NewServer(r.task.Name(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server)
r.task.OnCancel("remove server from lb", func() {
r.task.OnCancel("lb_remove_server", func() {
lb.RemoveServer(r.server)
})
}

View file

@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler {
}
}
func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) {
func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) {
oldRoutes := handler.provider.routes
newRoutes, err := handler.provider.loadRoutesImpl()
if err != nil {
@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
return false
}
func (handler *EventHandler) Add(parent *task.Task, route *route.Route) {
func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
err := handler.provider.startRoute(parent, route)
if err != nil {
handler.errs.Add(err.Subject("add"))
@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) {
handler.removed.Adds(route.Entry.Alias)
}
func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) {
func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) {
oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute)
if err != nil {

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"path"
"strings"
"time"
"github.com/rs/zerolog"
@ -60,7 +61,7 @@ func NewFileProvider(filename string) (p *Provider, err error) {
if name == "" {
return nil, ErrEmptyProviderName
}
p = newProvider(name, ProviderTypeFile)
p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile)
p.ProviderImpl, err = FileProviderImpl(filename)
if err != nil {
return nil, err
@ -100,46 +101,43 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error {
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
err := r.Start(subtask)
func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error {
err := r.Start(parent)
if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err) // just to ensure
return err.Subject(r.Entry.Alias)
}
p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
r.Task().OnFinished("provider_remove_route", func() {
p.routes.Delete(r.Entry.Alias)
})
return nil
}
// Start implements*task.TaskStarter.
func (p *Provider) Start(configSubtask *task.Task) E.Error {
// routes and event queue will stop on parent cancel
providerTask := configSubtask
func (p *Provider) Start(parent task.Parent) E.Error {
t := parent.Subtask("provider."+p.name, false)
// routes and event queue will stop on config reload
errs := p.routes.CollectErrorsParallel(
func(alias string, r *R.Route) error {
return p.startRoute(providerTask, r)
return p.startRoute(t, r)
})
eventQueue := events.NewEventQueue(
providerTask,
t.Subtask("event_queue", false),
providerEventFlushInterval,
func(flushTask *task.Task, events []events.Event) {
func(events []events.Event) {
handler := p.newEventHandler()
// routes' lifetime should follow the provider's lifetime
handler.Handle(providerTask, events)
handler.Handle(t, events)
handler.Log()
flushTask.Finish("events flushed")
},
func(err E.Error) {
E.LogError("event error", err, p.Logger())
},
)
eventQueue.Start(p.watcher.Events(providerTask.Context()))
eventQueue.Start(p.watcher.Events(t.Context()))
if err := E.Join(errs...); err != nil {
return err.Subject(p.String())

View file

@ -47,20 +47,22 @@ func (r *StreamRoute) String() string {
return "stream " + r.TargetName()
}
// Start implements*task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
// Start implements task.TaskStarter.
func (r *StreamRoute) Start(parent task.Parent) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
}
r.task = providerSubtask
r.task = parent.Subtask("stream." + r.TargetName())
r.Stream = NewStream(r)
parent.OnCancel("finish", func() {
r.task.Finish(nil)
})
switch {
case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
waker, err := idlewatcher.NewStreamWaker(r.task, r.StreamEntry, r.Stream)
if err != nil {
r.task.Finish(err)
return err
@ -73,7 +75,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
if err == nil {
fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck)
r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
r.task.OnCancel("close docker client", client.Close)
r.task.OnCancel("close_docker_client", client.Close)
}
}
if r.HealthMon == nil {
@ -86,7 +88,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
return E.From(err)
}
r.task.OnFinished("close stream", func() {
r.task.OnFinished("close_stream", func() {
if err := r.Stream.Close(); err != nil {
E.LogError("close stream failed", err, &r.l)
}
@ -97,22 +99,26 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
Msg("listening")
if r.HealthMon != nil {
healthMonTask := r.task.Subtask("health monitor")
if err := r.HealthMon.Start(healthMonTask); err != nil {
if err := r.HealthMon.Start(r.task); err != nil {
E.LogWarn("health monitor error", err, &r.l)
healthMonTask.Finish(err)
}
}
go r.acceptConnections()
routes.SetStreamRoute(r.TargetName(), r)
r.task.OnFinished("remove from route table", func() {
r.task.OnFinished("entrypoint_remove_route", func() {
routes.DeleteStreamRoute(r.TargetName())
})
return nil
}
// Task implements task.TaskStarter.
func (r *StreamRoute) Task() *task.Task {
return r.task
}
// Finish implements task.TaskFinisher.
func (r *StreamRoute) Finish(reason any) {
r.task.Finish(reason)
}

View file

@ -95,7 +95,7 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
return fmt.Errorf("unexpected listener type: %T", stream)
}
case io.ReadWriteCloser:
stream.task.OnCancel("close conn", func() { conn.Close() })
stream.task.OnCancel("close_conn", func() { conn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())

View file

@ -4,355 +4,194 @@ import (
"context"
"errors"
"fmt"
"strings"
"runtime/debug"
"sync"
"time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var globalTask = createGlobalTask()
func createGlobalTask() (t *Task) {
t = new(Task)
t.name = "root"
t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.subtasks = F.NewSet[*Task]()
return
}
func testResetGlobalTask() {
globalTask = createGlobalTask()
}
type (
TaskStarter interface {
// Start starts the object that implements TaskStarter,
// and returns an error if it fails to start.
//
// The task passed must be a subtask of the caller task.
//
// callerSubtask.Finish must be called when start fails or the object is finished.
Start(callerSubtask *Task) E.Error
Start(parent Parent) E.Error
Task() *Task
}
TaskFinisher interface {
// Finish marks the task as finished and cancel its context.
//
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
//
// Note that it will also cancel all subtasks.
Finish(reason any)
}
// Task controls objects' lifetime.
//
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
//
// When passing a Task object to another function,
// it must be a sub-Task of the current Task,
// in name of "`currentTaskName`Subtask"
//
// Use Task.Finish to stop all subtasks of the Task.
Task struct {
name string
children sync.WaitGroup
onFinished sync.WaitGroup
finished chan struct{}
ctx context.Context
cancel context.CancelCauseFunc
parent *Task
subtasks F.Set[*Task]
subTasksWg sync.WaitGroup
name string
OnFinishedFuncs []func()
OnFinishedMu sync.Mutex
onFinishedWg sync.WaitGroup
finishOnce sync.Once
once sync.Once
}
Parent interface {
Context() context.Context
Subtask(name string, needFinish ...bool) *Task
Name() string
Finish(reason any)
OnCancel(name string, f func())
}
)
var (
ErrProgramExiting = errors.New("program exiting")
ErrTaskCanceled = errors.New("task canceled")
logger = logging.With().Str("module", "task").Logger()
)
// GlobalTask returns a new Task with the given name, derived from the global context.
func GlobalTask(format string, args ...any) *Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return globalTask.Subtask(format)
}
// DebugTaskMap returns a map[string]any representation of the global task tree.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func DebugTaskMap() map[string]any {
return globalTask.serialize()
}
// CancelGlobalContext cancels the global task context, which will cause all tasks
// created to be canceled. This should be called before exiting the program
// to ensure that all tasks are properly cleaned up.
func CancelGlobalContext() {
globalTask.cancel(ErrProgramExiting)
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) (err error) {
done := make(chan struct{})
after := time.After(timeout)
go func() {
globalTask.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
return context.DeadlineExceeded
}
}
}
func (t *Task) trace(msg string) {
logger.Trace().Str("name", t.name).Msg(msg)
}
// Name returns the name of the task.
func (t *Task) Name() string {
if !common.IsTrace {
return t.name
}
parts := strings.Split(t.name, " > ")
return parts[len(parts)-1]
}
// String returns the name of the task.
func (t *Task) String() string {
return t.name
}
// Context returns the context associated with the task. This context is
// canceled when Finish of the task is called, or parent task is canceled.
func (t *Task) Context() context.Context {
return t.ctx
}
func (t *Task) Finished() <-chan struct{} {
return t.finished
}
// FinishCause returns the reason / error that caused the task to be finished.
func (t *Task) FinishCause() error {
cause := context.Cause(t.ctx)
if cause == nil {
return t.ctx.Err()
}
return cause
return context.Cause(t.ctx)
}
// Parent returns the parent task of the current task.
func (t *Task) Parent() *Task {
return t.parent
}
func (t *Task) runAllOnFinished(onCompTask *Task) {
<-t.ctx.Done()
t.WaitSubTasks()
for _, OnFinishedFunc := range t.OnFinishedFuncs {
OnFinishedFunc()
t.onFinishedWg.Done()
}
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
}
// OnFinished calls fn when all subtasks are finished.
// OnFinished calls fn when the task is canceled and all subtasks are finished.
//
// It cannot be called after Finish or Wait is called.
// It should not be called after Finish is called.
func (t *Task) OnFinished(about string, fn func()) {
if t.parent == globalTask {
t.OnCancel(about, fn)
return
}
t.onFinishedWg.Add(1)
t.OnFinishedMu.Lock()
defer t.OnFinishedMu.Unlock()
if t.OnFinishedFuncs == nil {
onCompTask := GlobalTask(t.name + " > OnFinished > " + about)
go t.runAllOnFinished(onCompTask)
}
idx := len(t.OnFinishedFuncs)
wrapped := func() {
defer func() {
if err := recover(); err != nil {
logger.Error().
Str("name", t.name).
Interface("err", err).
Msg("panic in " + about)
}
}()
fn()
logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about)
}
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
t.onCancel(about, fn, true)
}
// OnCancel calls fn when the task is canceled.
//
// It cannot be called after Finish or Wait is called.
// It should not be called after Finish is called.
func (t *Task) OnCancel(about string, fn func()) {
onCompTask := GlobalTask(t.name + " > OnFinished")
t.onCancel(about, fn, false)
}
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
t.onFinished.Add(1)
go func() {
<-t.ctx.Done()
fn()
onCompTask.Finish("done")
t.trace("onCancel done: " + about)
if waitSubTasks {
t.children.Wait()
}
t.invokeWithRecover(fn, about)
t.onFinished.Done()
}()
}
// Finish marks the task as finished and cancel its context.
//
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
//
// Note that it will also cancel all subtasks.
// Finish cancel all subtasks and wait for them to finish,
// then marks the task as finished, with the given reason (if any).
func (t *Task) Finish(reason any) {
var format string
switch reason.(type) {
case error:
format = "%w"
case string, fmt.Stringer:
format = "%s"
select {
case <-t.finished:
return
default:
format = "%v"
t.once.Do(func() {
t.finish(reason)
})
}
}
func (t *Task) finish(reason any) {
t.cancel(fmtCause(reason))
t.children.Wait()
t.onFinished.Wait()
if t.finished != nil {
close(t.finished)
}
logger.Trace().Msg("task " + t.name + " finished")
}
func fmtCause(cause any) error {
switch cause := cause.(type) {
case nil:
return nil
case error:
return cause
case string:
return errors.New(cause)
default:
return fmt.Errorf("%v", cause)
}
t.finishOnce.Do(func() {
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
})
t.Wait()
}
// Subtask returns a new subtask with the given name, derived from the parent's context.
//
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
//
// This should not be called after Finish, Wait, or WaitSubTasks is called.
func (t *Task) Subtask(name string) *Task {
ctx, cancel := context.WithCancelCause(t.ctx)
return t.newSubTask(ctx, cancel, name)
}
// This should not be called after Finish is called.
func (t *Task) Subtask(name string, needFinish ...bool) *Task {
nf := len(needFinish) == 0 || needFinish[0]
func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task {
parent := t
if common.IsTrace {
name = parent.name + " > " + name
}
subtask := &Task{
ctx, cancel := context.WithCancelCause(t.ctx)
child := &Task{
finished: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
name: name,
parent: parent,
subtasks: F.NewSet[*Task](),
}
parent.subTasksWg.Add(1)
parent.subtasks.Add(subtask)
if common.IsTrace {
subtask.trace("started")
if t != root {
child.name = t.name + "." + name
allTasks.Add(child)
} else {
child.name = name
}
allTasksWg.Add(1)
t.children.Add(1)
if !nf {
go func() {
subtask.Wait()
subtask.trace("finished: " + subtask.FinishCause().Error())
<-child.ctx.Done()
child.Finish(nil)
}()
}
go func() {
subtask.Wait()
parent.subtasks.Remove(subtask)
parent.subTasksWg.Done()
<-child.finished
allTasksWg.Done()
t.children.Done()
allTasks.Remove(child)
}()
return subtask
logger.Trace().Msg("task " + child.name + " started")
return child
}
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
//
// It must be called only after Finish is called.
func (t *Task) Wait() {
<-t.ctx.Done()
t.WaitSubTasks()
t.onFinishedWg.Wait()
// Name returns the name of the task without parent names.
func (t *Task) Name() string {
parts := strutils.SplitRune(t.name, '.')
return parts[len(parts)-1]
}
// WaitSubTasks waits for all subtasks of the task to finish.
//
// No more subtasks can be added after this call.
//
// It can be called before Finish is called.
func (t *Task) WaitSubTasks() {
t.subTasksWg.Wait()
// String returns the full name of the task.
func (t *Task) String() string {
return t.name
}
// tree returns a string representation of the task tree, with the given
// prefix prepended to each line. The prefix is used to indent the tree,
// and should be a string of spaces or a similar separator.
//
// The resulting string is suitable for printing to the console, and can be
// used to debug the task tree.
//
// The tree is traversed in a depth-first manner, with each task's name and
// line number (if available) printed on a separate line. The line number is
// only printed if the task was created with a non-empty line argument.
//
// The returned string is not guaranteed to be stable, and may change between
// runs of the program. It is intended for debugging purposes only.
func (t *Task) tree(prefix ...string) string {
var sb strings.Builder
var pre string
if len(prefix) > 0 {
pre = prefix[0]
sb.WriteString(pre + "- ")
}
sb.WriteString(t.Name() + "\n")
t.subtasks.RangeAll(func(subtask *Task) {
sb.WriteString(subtask.tree(pre + " "))
})
return sb.String()
// MarshalText implements encoding.TextMarshaler.
func (t *Task) MarshalText() ([]byte, error) {
return []byte(t.name), nil
}
// serialize returns a map[string]any representation of the task tree.
//
// The map contains the following keys:
// - name: the name of the task
// - subtasks: a slice of maps, each representing a subtask
//
// The subtask maps contain the same keys, recursively.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func (t *Task) serialize() map[string]any {
m := make(map[string]any)
parts := strings.Split(t.name, " > ")
m["name"] = parts[len(parts)-1]
if t.subtasks.Size() > 0 {
m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
t.subtasks.RangeAll(func(subtask *Task) {
m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
})
}
return m
func (t *Task) invokeWithRecover(fn func(), caller string) {
defer func() {
if err := recover(); err != nil {
logger.Error().
Interface("err", err).
Msg("panic in task " + t.name + "." + caller)
if common.IsDebug {
panic(string(debug.Stack()))
}
}
}()
fn()
}

View file

@ -2,132 +2,112 @@ package task
import (
"context"
"sync/atomic"
"sync"
"testing"
"time"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
const (
rootTaskName = "root-task"
subTaskName = "subtask"
)
func TestTaskCreation(t *testing.T) {
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
ExpectEqual(t, rootTaskName, rootTask.Name())
ExpectEqual(t, subTaskName, subTask.Name())
func testTask() *Task {
return RootTask("test", false)
}
func TestTaskCancellation(t *testing.T) {
subTaskDone := make(chan struct{})
func TestChildTaskCancellation(t *testing.T) {
t.Cleanup(testCleanup)
rootTask := GlobalTask(rootTaskName)
subTask := rootTask.Subtask(subTaskName)
parent := testTask()
child := parent.Subtask("")
go func() {
subTask.Wait()
close(subTaskDone)
defer child.Finish(nil)
for {
select {
case <-child.Context().Done():
return
default:
continue
}
}
}()
go rootTask.Finish(nil)
parent.cancel(nil) // should also cancel child
select {
case <-subTaskDone:
err := subTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(subTask.Context())
ExpectError(t, ErrTaskCanceled, cause)
case <-time.After(1 * time.Second):
case <-child.Finished():
ExpectError(t, context.Canceled, child.Context().Err())
default:
t.Fatal("subTask context was not canceled as expected")
}
}
func TestOnComplete(t *testing.T) {
rootTask := GlobalTask(rootTaskName)
task := rootTask.Subtask(subTaskName)
func TestTaskOnCancelOnFinished(t *testing.T) {
t.Cleanup(testCleanup)
task := testTask()
var value atomic.Int32
task.OnFinished("set value", func() {
value.Store(1234)
var shouldTrueOnCancel bool
var shouldTrueOnFinish bool
task.OnCancel("", func() {
shouldTrueOnCancel = true
})
task.OnFinished("", func() {
shouldTrueOnFinish = true
})
ExpectFalse(t, shouldTrueOnFinish)
task.Finish(nil)
ExpectEqual(t, value.Load(), 1234)
ExpectTrue(t, shouldTrueOnCancel)
ExpectTrue(t, shouldTrueOnFinish)
}
func TestGlobalContextWait(t *testing.T) {
testResetGlobalTask()
defer CancelGlobalContext()
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
t.Cleanup(testCleanup)
task := testTask()
rootTask := GlobalTask(rootTaskName)
finished := false
finished1, finished2 := false, false
subTask1 := rootTask.Subtask(subTaskName)
subTask2 := rootTask.Subtask(subTaskName)
subTask1.OnFinished("", func() {
finished1 = true
})
subTask2.OnFinished("", func() {
finished2 = true
task.OnFinished("", func() {
finished = true
})
go func() {
time.Sleep(500 * time.Millisecond)
subTask1.Finish(nil)
defer task.Finish(nil)
for {
select {
case <-task.Context().Done():
return
default:
continue
}
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
subTask2.Finish(nil)
}()
ExpectNoError(t, GracefulShutdown(1*time.Second))
ExpectTrue(t, finished)
go func() {
subTask1.Wait()
subTask2.Wait()
rootTask.Finish(nil)
}()
_ = GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
<-root.finished
ExpectError(t, context.Canceled, task.Context().Err())
ExpectError(t, ErrProgramExiting, task.FinishCause())
}
func TestTimeoutOnGlobalContextWait(t *testing.T) {
testResetGlobalTask()
func TestTimeoutOnGracefulShutdown(t *testing.T) {
t.Cleanup(testCleanup)
_ = testTask()
rootTask := GlobalTask(rootTaskName)
rootTask.Subtask(subTaskName)
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
}
func TestGlobalContextCancellation(t *testing.T) {
testResetGlobalTask()
taskDone := make(chan struct{})
rootTask := GlobalTask(rootTaskName)
go func() {
rootTask.Wait()
close(taskDone)
}()
CancelGlobalContext()
select {
case <-taskDone:
err := rootTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(rootTask.Context())
ExpectError(t, ErrProgramExiting, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected")
func TestFinishMultipleCalls(t *testing.T) {
t.Cleanup(testCleanup)
task := testTask()
var wg sync.WaitGroup
wg.Add(5)
for range 5 {
go func() {
defer wg.Done()
task.Finish(nil)
}()
}
wg.Wait()
}

96
internal/task/utils.go Normal file
View file

@ -0,0 +1,96 @@
package task
import (
"context"
"encoding/json"
"errors"
"slices"
"sync"
"time"
"github.com/yusing/go-proxy/internal/logging"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
var ErrProgramExiting = errors.New("program exiting")
var logger = logging.With().Str("module", "task").Logger()
var root = newRoot()
var allTasks = F.NewSet[*Task]()
var allTasksWg sync.WaitGroup
func testCleanup() {
root = newRoot()
allTasks.Clear()
allTasksWg = sync.WaitGroup{}
}
// RootTask returns a new Task with the given name, derived from the root context.
func RootTask(name string, needFinish bool) *Task {
return root.Subtask(name, needFinish)
}
func newRoot() *Task {
t := &Task{name: "root"}
t.ctx, t.cancel = context.WithCancelCause(context.Background())
return t
}
func RootContext() context.Context {
return root.ctx
}
func RootContextCanceled() <-chan struct{} {
return root.ctx.Done()
}
func OnProgramExit(about string, fn func()) {
root.OnFinished(about, fn)
}
// GracefulShutdown waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GracefulShutdown(timeout time.Duration) (err error) {
root.cancel(ErrProgramExiting)
done := make(chan struct{})
after := time.After(timeout)
go func() {
allTasksWg.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
b, err := json.Marshal(DebugTaskList())
if err != nil {
logger.Warn().Err(err).Msg("failed to marshal tasks")
return context.DeadlineExceeded
}
logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
return context.DeadlineExceeded
}
}
}
// DebugTaskList returns list of all tasks.
//
// The returned string is suitable for printing to the console.
func DebugTaskList() []string {
l := make([]string, 0, allTasks.Size())
allTasks.RangeAll(func(t *Task) {
l = append(l, t.name)
})
slices.Sort(l)
return l
}

View file

@ -1,30 +0,0 @@
package utils
import (
"encoding/json"
"sync/atomic"
)
type AtomicValue[T any] struct {
atomic.Value
}
func (a *AtomicValue[T]) Load() T {
return a.Value.Load().(T)
}
func (a *AtomicValue[T]) Store(v T) {
a.Value.Store(v)
}
func (a *AtomicValue[T]) Swap(v T) T {
return a.Value.Swap(v).(T)
}
func (a *AtomicValue[T]) CompareAndSwap(oldV, newV T) bool {
return a.Value.CompareAndSwap(oldV, newV)
}
func (a *AtomicValue[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.Load())
}

View file

@ -0,0 +1,30 @@
package atomic
import (
"encoding/json"
"sync/atomic"
)
type Value[T any] struct {
atomic.Value
}
func (a *Value[T]) Load() T {
return a.Value.Load().(T)
}
func (a *Value[T]) Store(v T) {
a.Value.Store(v)
}
func (a *Value[T]) Swap(v T) T {
return a.Value.Swap(v).(T)
}
func (a *Value[T]) CompareAndSwap(oldV, newV T) bool {
return a.Value.CompareAndSwap(oldV, newV)
}
func (a *Value[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.Load())
}

View file

@ -152,9 +152,10 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
return m.CollectErrors(do)
}
errs := make([]error, 0)
mu := sync.Mutex{}
wg := sync.WaitGroup{}
var errs []error
var mu sync.Mutex
var wg sync.WaitGroup
m.Range(func(k KT, v VT) bool {
wg.Add(1)
go func() {
@ -171,24 +172,6 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
return errs
}
// RemoveAll removes all key-value pairs from the map where the value matches the given criteria.
//
// Parameters:
//
// criteria: function to determine whether a value should be removed
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool {
if criteria(v) {
m.Delete(k)
}
return true
})
}
func (m Map[KT, VT]) Has(k KT) bool {
_, ok := m.Load(k)
return ok

View file

@ -61,3 +61,7 @@ func (set Set[T]) RangeAllParallel(f func(T)) {
func (set Set[T]) Size() int {
return set.m.Size()
}
func (set Set[T]) IsEmpty() bool {
return set.m == nil || set.m.Size() == 0
}

View file

@ -20,10 +20,17 @@ func IgnoreError[Result any](r Result, _ error) Result {
return r
}
func fmtError(err error) string {
if err == nil {
return "<nil>"
}
return ansi.StripANSI(err.Error())
}
func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", ansi.StripANSI(err.Error()))
if err != nil {
t.Errorf("expected err=nil, got %s", fmtError(err))
t.FailNow()
}
}
@ -31,7 +38,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected, ansi.StripANSI(err.Error()))
t.Errorf("expected err %s, got %s", expected, fmtError(err))
t.FailNow()
}
}
@ -39,7 +46,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected, ansi.StripANSI(err.Error()))
t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err))
t.FailNow()
}
}
@ -48,7 +55,7 @@ func ExpectErrorT[T error](t *testing.T, err error) {
t.Helper()
var errAs T
if !errors.As(err, &errAs) {
t.Errorf("expected err %T, got %s", errAs, ansi.StripANSI(err.Error()))
t.Errorf("expected err %T, got %s", errAs, fmtError(err))
t.FailNow()
}
}

View file

@ -16,8 +16,10 @@ var (
func NewConfigFileWatcher(filename string) Watcher {
configDirWatcherMu.Lock()
defer configDirWatcherMu.Unlock()
if configDirWatcher == nil {
configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath)
t := task.RootTask("config_dir_watcher", false)
configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath)
}
return configDirWatcher.Add(filename)
}

View file

@ -37,7 +37,7 @@ type DirWatcher struct {
//
// Note that the returned DirWatcher is not ready to use until the goroutine
// started by NewDirectoryWatcher has finished.
func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
//! subdirectories are not watched
w, err := fsnotify.NewWatcher()
if err != nil {
@ -56,7 +56,7 @@ func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
fwMap: F.NewMapOf[string, *fileWatcher](),
eventCh: make(chan Event),
errCh: make(chan E.Error),
task: callerSubtask,
task: parent.Subtask("dir_watcher(" + dirPath + ")"),
}
go helper.start()
return helper
@ -80,17 +80,19 @@ func (h *DirWatcher) Add(relPath string) Watcher {
eventCh: make(chan Event),
errCh: make(chan E.Error),
}
h.task.OnFinished("close file watcher for "+relPath, func() {
close(s.eventCh)
close(s.errCh)
})
h.fwMap.Store(relPath, s)
return s
}
func (h *DirWatcher) cleanup() {
h.w.Close()
close(h.eventCh)
close(h.errCh)
h.task.Finish(nil)
}
func (h *DirWatcher) start() {
defer close(h.eventCh)
defer h.w.Close()
defer h.cleanup()
for {
select {

View file

@ -1,6 +1,7 @@
package events
import (
"runtime/debug"
"time"
"github.com/yusing/go-proxy/internal/common"
@ -17,7 +18,7 @@ type (
onFlush OnFlushFunc
onError OnErrorFunc
}
OnFlushFunc = func(flushTask *task.Task, events []Event)
OnFlushFunc = func(events []Event)
OnErrorFunc = func(err E.Error)
)
@ -38,9 +39,9 @@ const eventQueueCapacity = 10
// but the onFlush function can return earlier (e.g. run in another goroutine).
//
// If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
func NewEventQueue(queueTask *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
return &EventQueue{
task: parent.Subtask("event queue"),
task: queueTask,
queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval),
flushInterval: flushInterval,
@ -50,19 +51,20 @@ func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlu
}
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
if common.IsProduction {
origOnFlush := e.onFlush
// recover panic in onFlush when in production mode
e.onFlush = func(flushTask *task.Task, events []Event) {
defer func() {
if err := recover(); err != nil {
e.onError(E.New("recovered panic in onFlush").
Withf("%v", err).
Subject(e.task.Parent().String()))
origOnFlush := e.onFlush
// recover panic in onFlush when in production mode
e.onFlush = func(events []Event) {
defer func() {
if err := recover(); err != nil {
e.onError(E.New("recovered panic in onFlush").
Withf("%v", err).
Subject(e.task.Name()))
if common.IsDebug {
panic(string(debug.Stack()))
}
}()
origOnFlush(flushTask, events)
}
}
}()
origOnFlush(events)
}
go func() {
@ -75,19 +77,24 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
return
case <-e.ticker.C:
if len(e.queue) > 0 {
flushTask := e.task.Subtask("flush events")
queue := e.queue
e.queue = make([]Event, 0, eventQueueCapacity)
go e.onFlush(flushTask, queue)
flushTask.Wait()
// clone -> clear -> flush
queue := make([]Event, len(e.queue))
copy(queue, e.queue)
e.queue = e.queue[:0]
e.onFlush(queue)
}
e.ticker.Reset(e.flushInterval)
case event, ok := <-eventCh:
e.queue = append(e.queue, event)
if !ok {
return
}
case err := <-errCh:
e.queue = append(e.queue, event)
case err, ok := <-errCh:
if !ok {
return
}
if err != nil {
e.onError(err)
}
@ -95,10 +102,3 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
}
}()
}
// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
e.task.Wait()
}

View file

@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -23,9 +23,9 @@ type (
monitor struct {
service string
config *health.HealthCheckConfig
url U.AtomicValue[types.URL]
url atomic.Value[types.URL]
status U.AtomicValue[health.Status]
status atomic.Value[health.Status]
lastResult *health.HealthCheckResult
lastSeen time.Time
@ -59,10 +59,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
}
// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask
func (mon *monitor) Start(parent task.Parent) E.Error {
if mon.config.Interval <= 0 {
return E.From(ErrNegativeInterval)
}
@ -71,6 +68,9 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service))
}
mon.service = parent.Name()
mon.task = parent.Subtask("health_monitor")
go func() {
logger := logging.With().Str("name", mon.service).Logger()
@ -78,10 +78,10 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
if mon.status.Load() != health.StatusError {
mon.status.Store(health.StatusUnknown)
}
mon.task.Finish(nil)
if mon.metric != nil {
mon.metric.Reset()
}
mon.task.Finish(nil)
}()
if err := mon.checkUpdateHealth(); err != nil {
@ -108,6 +108,11 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
return nil
}
// Task implements task.TaskStarter.
func (mon *monitor) Task() *task.Task {
return mon.task
}
// Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason any) {
mon.task.Finish(reason)