Mirror of https://github.com/yusing/godoxy.git, synced 2025-05-20 04:42:33 +02:00

simplify task package implementation

This commit is contained in:
parent e7aaa95ec5
commit 1ab34ed46f

35 changed files with 547 additions and 600 deletions
@@ -23,16 +23,16 @@ lint:
   enabled:
     - hadolint@2.12.1-beta
     - actionlint@1.7.4
-    - checkov@3.2.334
+    - checkov@3.2.344
     - git-diff-check
     - gofmt@1.20.4
     - golangci-lint@1.62.2
-    - osv-scanner@1.9.1
+    - osv-scanner@1.9.2
     - oxipng@9.1.3
     - prettier@3.4.2
     - shellcheck@0.10.0
     - shfmt@3.6.0
-    - trufflehog@3.86.1
+    - trufflehog@3.88.0
 actions:
   disabled:
     - trunk-announce
Makefile (4 changed lines)
@@ -28,10 +28,10 @@ get:
 	go get -u ./cmd && go mod tidy
 
 debug:
-	GODOXY_DEBUG=1 make run
+	GODOXY_DEBUG=1 BUILD_FLAGS="" make run
 
 debug-trace:
-	GODOXY_DEBUG=1 GODOXY_TRACE=1 run
+	GODOXY_TRACE=1 make debug
 
 profile:
 	GODEBUG=gctrace=1 make debug
@@ -159,8 +159,7 @@ func main() {
 
 	// grafully shutdown
 	logging.Info().Msg("shutting down")
-	task.CancelGlobalContext()
-	_ = task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
+	_ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown))
 }
 
 func prepareDirectory(dir string) {
@@ -52,7 +52,7 @@ func List(w http.ResponseWriter, r *http.Request) {
 	case ListHomepageConfig:
 		U.RespondJSON(w, r, config.HomepageConfig())
 	case ListTasks:
-		U.RespondJSON(w, r, task.DebugTaskMap())
+		U.RespondJSON(w, r, task.DebugTaskList())
 	default:
 		U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
 	}
@@ -153,8 +153,8 @@ func (p *Provider) ScheduleRenewal() {
 		return
 	}
 	go func() {
-		task := task.GlobalTask("cert renew scheduler")
-		defer task.Finish("cert renew scheduler stopped")
+		task := task.RootTask("cert-renew-scheduler", true)
+		defer task.Finish(nil)
 
 		for {
 			renewalTime := p.ShouldRenewOn()
@@ -53,7 +53,7 @@ func newConfig() *Config {
 	return &Config{
 		value:     types.DefaultConfig(),
 		providers: F.NewMapOf[string, *proxy.Provider](),
-		task:      task.GlobalTask("config"),
+		task:      task.RootTask("config", false),
 	}
 }
 
@@ -76,21 +76,19 @@ func MatchDomains() []string {
 }
 
 func WatchChanges() {
-	task := task.GlobalTask("config watcher")
+	t := task.RootTask("config_watcher", true)
 	eventQueue := events.NewEventQueue(
-		task,
+		t,
 		configEventFlushInterval,
 		OnConfigChange,
 		func(err E.Error) {
 			E.LogError("config reload error", err, &logger)
 		},
 	)
-	eventQueue.Start(cfgWatcher.Events(task.Context()))
+	eventQueue.Start(cfgWatcher.Events(t.Context()))
 }
 
-func OnConfigChange(flushTask *task.Task, ev []events.Event) {
-	defer flushTask.Finish("config reload complete")
-
+func OnConfigChange(ev []events.Event) {
 	// no matter how many events during the interval
 	// just reload once and check the last event
 	switch ev[len(ev)-1].Action {
@@ -116,14 +114,14 @@ func Reload() E.Error {
 	newCfg := newConfig()
 	err := newCfg.load()
 	if err != nil {
+		newCfg.task.Finish(err)
 		return err
 	}
 
 	// cancel all current subtasks -> wait
 	// -> replace config -> start new subtasks
 	instance.task.Finish("config changed")
-	instance.task.Wait()
-	*instance = *newCfg
+	instance = newCfg
 	instance.StartProxyProviders()
 	return nil
 }
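The switch from *instance = *newCfg to instance = newCfg replaces a struct copy with a pointer swap: the copy would overwrite memory that goroutines holding the old *Config may still be reading (and would copy any internal locks), while the swap leaves the old value intact for them. A minimal sketch of the difference, using a hypothetical Config rather than the project's real type:

package main

import "sync"

type Config struct {
	mu    sync.Mutex // copying a struct that contains a mutex is a vet error and a race
	value int
}

var instance = &Config{value: 1}

func reloadByCopy(newCfg *Config) {
	*instance = *newCfg // overwrites memory other goroutines may be reading; copies mu
}

func reloadBySwap(newCfg *Config) {
	// The old *Config stays valid for goroutines that already hold it.
	// The swap itself still needs external synchronization if readers
	// fetch the pointer concurrently.
	instance = newCfg
}

func main() {
	reloadBySwap(&Config{value: 2})
}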
@@ -143,8 +141,7 @@ func (cfg *Config) Task() *task.Task {
 func (cfg *Config) StartProxyProviders() {
 	errs := cfg.providers.CollectErrorsParallel(
 		func(_ string, p *proxy.Provider) error {
-			subtask := cfg.task.Subtask(p.String())
-			return p.Start(subtask)
+			return p.Start(cfg.task)
 		})
 
 	if err := E.Join(errs...); err != nil {
@@ -209,9 +206,6 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
 }
 
 func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
-	subtask := cfg.task.Subtask("load route providers")
-	defer subtask.Finish("done")
-
 	errs := E.NewBuilder("route provider errors")
 	results := E.NewBuilder("loaded route providers")
 
@@ -38,7 +38,7 @@ var (
 )
 
 func init() {
-	task.GlobalTask("close docker clients").OnFinished("", func() {
+	task.OnProgramExit("docker_clients_cleanup", func() {
 		clientMap.RangeAllParallel(func(_ string, c Client) {
 			if c.Connected() {
 				c.Client.Close()
@@ -38,7 +38,7 @@ const (
 
 // TODO: support stream
 
-func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
+func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
 	hcCfg := entry.RawEntry().HealthCheck
 	hcCfg.Timeout = idleWakerCheckTimeout
 
@@ -46,8 +46,8 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
 		rp:     rp,
 		stream: stream,
 	}
-
-	watcher, err := registerWatcher(providerSubTask, entry, waker)
+	task := parent.Subtask("idlewatcher")
+	watcher, err := registerWatcher(task, entry, waker)
 	if err != nil {
 		return nil, E.Errorf("register watcher: %w", err)
 	}
@@ -63,7 +63,7 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
 
 	if common.PrometheusEnabled {
 		m := metrics.GetServiceMetrics()
-		fqn := providerSubTask.Parent().Name() + "/" + entry.TargetName()
+		fqn := parent.Name() + "/" + entry.TargetName()
 		waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn))
 		waker.metric.Set(float64(watcher.Status()))
 	}
@@ -71,19 +71,18 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
 }
 
 // lifetime should follow route provider.
-func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
-	return newWaker(providerSubTask, entry, rp, nil)
+func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
+	return newWaker(parent, entry, rp, nil)
 }
 
-func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) {
-	return newWaker(providerSubTask, entry, nil, stream)
+func NewStreamWaker(parent task.Parent, entry route.Entry, stream net.Stream) (Waker, E.Error) {
+	return newWaker(parent, entry, nil, stream)
 }
 
 // Start implements health.HealthMonitor.
-func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
-	routeSubTask.Finish("ignored")
-	w.task.OnCancel("stop route and cleanup", func() {
-		routeSubTask.Parent().Finish(w.task.FinishCause())
+func (w *Watcher) Start(parent task.Parent) E.Error {
+	w.task.OnCancel("route_cleanup", func() {
+		parent.Finish(w.task.FinishCause())
 		if w.metric != nil {
 			w.metric.Reset()
 		}
@@ -91,6 +90,11 @@ func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
 	return nil
 }
 
+// Task implements health.HealthMonitor.
+func (w *Watcher) Task() *task.Task {
+	return w.task
+}
+
 // Finish implements health.HealthMonitor.
 func (w *Watcher) Finish(reason any) {
 	if w.stream != nil {
@@ -51,7 +51,7 @@ var (
 
 const dockerReqTimeout = 3 * time.Second
 
-func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
+func registerWatcher(watcherTask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
 	cfg := entry.IdlewatcherConfig()
 
 	if cfg.IdleTimeout == 0 {

@@ -67,7 +67,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker
 		w.Config = cfg
 		w.waker = waker
 		w.resetIdleTimer()
-		providerSubtask.Finish("used existing watcher")
+		watcherTask.Finish("used existing watcher")
 		return w, nil
 	}
 

@@ -81,7 +81,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker
 		Config: cfg,
 		waker:  waker,
 		client: client,
-		task:   providerSubtask,
+		task:   watcherTask,
 		ticker: time.NewTicker(cfg.IdleTimeout),
 	}
 	w.stopByMethod = w.getStopCallback()

@@ -210,8 +210,7 @@ func (w *Watcher) resetIdleTimer() {
 }
 
 func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
-	eventTask = w.task.Subtask("docker event watcher")
-	eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
+	eventCh, errCh = dockerWatcher.EventsWithOptions(w.Task().Context(), watcher.DockerListOptions{
 		Filters: watcher.NewDockerFilter(
 			watcher.DockerFilterContainer,
 			watcher.DockerFilterContainerNameID(w.ContainerID),
@@ -54,7 +54,7 @@ func SetMiddlewares(mws []map[string]any) error {
 	return nil
 }
 
-func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) {
+func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
 	epAccessLoggerMu.Lock()
 	defer epAccessLoggerMu.Unlock()
 
@@ -50,10 +50,12 @@ func Join(errors ...error) Error {
 	if n == 0 {
 		return nil
 	}
-	errs := make([]error, 0, n)
+	errs := make([]error, n)
+	i := 0
 	for _, err := range errors {
 		if err != nil {
-			errs = append(errs, err)
+			errs[i] = err
+			i++
 		}
 	}
 	return &nestedError{Extras: errs}
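The rewrite works because n is already the count of non-nil errors, so the slice can be sized exactly and filled by index instead of growing through append. The same pattern in isolation, with the counting pass made explicit (names are illustrative, not the project's):

package main

import "fmt"

// joinNonNil keeps the non-nil errors from errs without append reallocation:
// it counts first, allocates exactly once, then fills by index.
func joinNonNil(errs ...error) []error {
	n := 0
	for _, err := range errs {
		if err != nil {
			n++
		}
	}
	if n == 0 {
		return nil
	}
	out := make([]error, n)
	i := 0
	for _, err := range errs {
		if err != nil {
			out[i] = err
			i++
		}
	}
	return out
}

func main() {
	fmt.Println(joinNonNil(nil, fmt.Errorf("boom"), nil)) // [boom]
}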
@@ -22,6 +22,7 @@ type (
 		Path    string  `json:"path" validate:"required"`
 		Filters Filters `json:"filters"`
 		Fields  Fields  `json:"fields"`
+		// Retention *Retention
 	}
 )
 

@@ -31,7 +32,7 @@ var (
 	FormatJSON Format = "json"
 )
 
-const DefaultBufferSize = 100
+const DefaultBufferSize = 64 * 1024 // 64KB
 
 func DefaultConfig() *Config {
 	return &Config{
internal/net/http/accesslog/file_logger.go (new file, 22 lines)
@@ -0,0 +1,22 @@
+package accesslog
+
+import (
+	"fmt"
+	"os"
+	"sync"
+
+	"github.com/yusing/go-proxy/internal/task"
+)
+
+type File struct {
+	*os.File
+	sync.Mutex
+}
+
+func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
+	f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
+	if err != nil {
+		return nil, fmt.Errorf("access log open error: %w", err)
+	}
+	return NewAccessLogger(parent, &File{File: f}, cfg), nil
+}
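A plausible call site for the new constructor; Path is the only Config field this diff guarantees, and the wiring names here are illustrative:

package main

import (
	"log"

	"github.com/yusing/go-proxy/internal/net/http/accesslog"
	"github.com/yusing/go-proxy/internal/task"
)

func main() {
	// The parent task ties the logger's lifetime to graceful shutdown.
	t := task.RootTask("accesslog", false)
	logger, err := accesslog.NewFileAccessLogger(t, &accesslog.Config{Path: "/tmp/access.log"})
	if err != nil {
		log.Fatal(err)
	}
	_ = logger // wiring into the reverse proxy is outside this sketch
}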
@@ -7,18 +7,17 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
+	"time"
 )
 
 type (
 	CommonFormatter struct {
-		cfg *Fields
+		cfg        *Fields
+		GetTimeNow func() time.Time // for testing purposes only
 	}
-	CombinedFormatter struct {
-		CommonFormatter
-	}
-	JSONFormatter struct {
-		CommonFormatter
-	}
+	CombinedFormatter CommonFormatter
+	JSONFormatter     CommonFormatter
 
 	JSONLogEntry struct {
 		Time string `json:"time"`
 		IP   string `json:"ip"`

@@ -39,6 +38,8 @@ type (
 	}
 )
 
+const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
+
 func scheme(req *http.Request) string {
 	if req.TLS != nil {
 		return "https"

@@ -62,7 +63,7 @@ func clientIP(req *http.Request) string {
 	return req.RemoteAddr
 }
 
-func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
 	query := f.cfg.Query.ProcessQuery(req.URL.Query())
 
 	line.WriteString(req.Host)

@@ -71,7 +72,7 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
 	line.WriteString(clientIP(req))
 	line.WriteString(" - - [")
 
-	line.WriteString(timeNow())
+	line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
 	line.WriteString("] \"")
 
 	line.WriteString(req.Method)

@@ -86,8 +87,8 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
 	line.WriteString(strconv.FormatInt(res.ContentLength, 10))
 }
 
-func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
-	f.CommonFormatter.Format(line, req, res)
+func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+	(*CommonFormatter)(f).Format(line, req, res)
 	line.WriteString(" \"")
 	line.WriteString(req.Referer())
 	line.WriteString("\" \"")

@@ -95,14 +96,14 @@ func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *ht
 	line.WriteRune('"')
 }
 
-func (f JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
 	query := f.cfg.Query.ProcessQuery(req.URL.Query())
 	headers := f.cfg.Headers.ProcessHeaders(req.Header)
 	headers.Del("Cookie")
 	cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())
 
 	entry := JSONLogEntry{
-		Time: timeNow(),
+		Time:   f.GetTimeNow().Format(LogTimeFormat),
 		IP:     clientIP(req),
 		Method: req.Method,
 		Scheme: scheme(req),
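CombinedFormatter and JSONFormatter change from structs embedding CommonFormatter to named types whose underlying type is CommonFormatter, so (*CommonFormatter)(f) is a zero-cost pointer conversion that reuses the base Format method. The idiom in isolation:

package main

import "fmt"

type base struct{ prefix string }

func (b *base) Format(s string) string { return b.prefix + s }

// combined has the same underlying type as base, so converting between
// *combined and *base is legal and costs nothing at runtime.
type combined base

func (c *combined) Format(s string) string {
	return (*base)(c).Format(s) + "!"
}

func main() {
	c := &combined{prefix: "> "}
	fmt.Println(c.Format("hello")) // > hello!
}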
@@ -9,6 +9,7 @@ import (
 	"github.com/yusing/go-proxy/internal/common"
 	E "github.com/yusing/go-proxy/internal/error"
 	"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
+	"github.com/yusing/go-proxy/internal/route/routes"
 	"github.com/yusing/go-proxy/internal/task"
 	"github.com/yusing/go-proxy/internal/watcher/health"
 	"github.com/yusing/go-proxy/internal/watcher/health/monitor"

@@ -52,10 +53,13 @@ func New(cfg *Config) *LoadBalancer {
 }
 
 // Start implements task.TaskStarter.
-func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
+func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
 	lb.startTime = time.Now()
-	lb.task = routeSubtask
-	lb.task.OnFinished("loadbalancer cleanup", func() {
+	lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
+	parent.OnCancel("lb_remove_route", func() {
+		routes.DeleteHTTPRoute(lb.Link)
+	})
+	lb.task.OnFinished("cleanup", func() {
 		if lb.impl != nil {
 			lb.pool.RangeAll(func(k string, v *Server) {
 				lb.impl.OnRemoveServer(v)

@@ -66,6 +70,11 @@ func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
 	return nil
 }
 
+// Task implements task.TaskStarter.
+func (lb *LoadBalancer) Task() *task.Task {
+	return lb.task
+}
+
 // Finish implements task.TaskFinisher.
 func (lb *LoadBalancer) Finish(reason any) {
 	lb.task.Finish(reason)
@@ -32,10 +32,10 @@ func setup() {
 		return
 	}
 
-	task := task.GlobalTask("error page")
-	dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
+	t := task.RootTask("error_page", true)
+	dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
 	loadContent()
-	go watchDir(task)
+	go watchDir()
 }
 
 func GetStaticFile(filename string) ([]byte, bool) {

@@ -73,11 +73,11 @@ func loadContent() {
 	}
 }
 
-func watchDir(task *task.Task) {
-	eventCh, errCh := dirWatcher.Events(task.Context())
+func watchDir() {
+	eventCh, errCh := dirWatcher.Events(task.RootContext())
 	for {
 		select {
-		case <-task.Context().Done():
+		case <-task.RootContextCanceled():
 			return
 		case event, ok := <-eventCh:
 			if !ok {
@@ -24,8 +24,6 @@ type Server struct {
 	httpsStarted bool
 	startTime    time.Time
 
-	task *task.Task
-
 	l zerolog.Logger
 }
 

@@ -76,7 +74,6 @@ func NewServer(opt Options) (s *Server) {
 		CertProvider: opt.CertProvider,
 		http:         httpSer,
 		https:        httpsSer,
-		task:         task.GlobalTask(opt.Name + " server"),
 		l:            logger,
 	}
 }

@@ -108,7 +105,7 @@ func (s *Server) Start() {
 		s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
 	}
 
-	s.task.OnFinished("stop server", s.stop)
+	task.OnProgramExit("server."+s.Name+".stop", s.stop)
 }
 
 func (s *Server) stop() {

@@ -117,12 +114,12 @@ func (s *Server) stop() {
 	}
 
 	if s.http != nil && s.httpStarted {
-		s.handleErr("http", s.http.Shutdown(s.task.Context()))
+		s.handleErr("http", s.http.Shutdown(task.RootContext()))
 		s.httpStarted = false
 	}
 
 	if s.https != nil && s.httpsStarted {
-		s.handleErr("https", s.https.Shutdown(s.task.Context()))
+		s.handleErr("https", s.https.Shutdown(task.RootContext()))
 		s.httpsStarted = false
 	}
 }
@@ -35,7 +35,7 @@ var (
 
 const dispatchErr = "notification dispatch error"
 
-func StartNotifDispatcher(parent *task.Task) *Dispatcher {
+func StartNotifDispatcher(parent task.Parent) *Dispatcher {
 	dispatcher = &Dispatcher{
 		task:  parent.Subtask("notification"),
 		logCh: make(chan *LogMessage),
@@ -73,19 +73,17 @@ func (r *HTTPRoute) String() string {
 	return r.TargetName()
 }
 
-// Start implements*task.TaskStarter.
-func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
+// Start implements task.TaskStarter.
+func (r *HTTPRoute) Start(parent task.Parent) E.Error {
 	if entry.ShouldNotServe(r) {
-		providerSubtask.Finish("should not serve")
 		return nil
 	}
 
-	r.task = providerSubtask
+	r.task = parent.Subtask("http."+r.TargetName(), false)
 
 	switch {
 	case entry.UseIdleWatcher(r):
-		wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
-		waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
+		waker, err := idlewatcher.NewHTTPWaker(r.task, r.ReverseProxyEntry, r.rp)
 		if err != nil {
 			r.task.Finish(err)
 			return err

@@ -98,7 +96,7 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
 		if err == nil {
 			fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck)
 			r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
-			r.task.OnCancel("close docker client", client.Close)
+			r.task.OnCancel("close_docker_client", client.Close)
 		}
 	}
 	if r.HealthMon == nil {
@@ -137,29 +135,32 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
 	}
 
 	if r.HealthMon != nil {
-		healthMonTask := r.task.Subtask("health monitor")
-		if err := r.HealthMon.Start(healthMonTask); err != nil {
+		if err := r.HealthMon.Start(r.task); err != nil {
 			E.LogWarn("health monitor error", err, &r.l)
-			healthMonTask.Finish(err)
 		}
 	}
 
 	if entry.UseLoadBalance(r) {
-		r.addToLoadBalancer()
+		r.addToLoadBalancer(parent)
 	} else {
 		routes.SetHTTPRoute(r.TargetName(), r)
-		r.task.OnFinished("remove from route table", func() {
+		r.task.OnFinished("entrypoint_remove_route", func() {
 			routes.DeleteHTTPRoute(r.TargetName())
 		})
 	}
 
 	if common.PrometheusEnabled {
-		r.task.OnFinished("unreg metrics", r.rp.UnregisterMetrics)
+		r.task.OnFinished("metrics_cleanup", r.rp.UnregisterMetrics)
 	}
 	return nil
 }
 
-// Finish implements*task.TaskFinisher.
+// Task implements task.TaskStarter.
+func (r *HTTPRoute) Task() *task.Task {
+	return r.task
+}
+
+// Finish implements task.TaskFinisher.
 func (r *HTTPRoute) Finish(reason any) {
 	r.task.Finish(reason)
 }

@@ -168,7 +169,7 @@ func (r *HTTPRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	r.handler.ServeHTTP(w, req)
 }
 
-func (r *HTTPRoute) addToLoadBalancer() {
+func (r *HTTPRoute) addToLoadBalancer(parent task.Parent) {
 	var lb *loadbalancer.LoadBalancer
 	cfg := r.Raw.LoadBalance
 	l, ok := routes.GetHTTPRoute(cfg.Link)

@@ -182,11 +183,7 @@ func (r *HTTPRoute) addToLoadBalancer() {
 		}
 	} else {
 		lb = loadbalancer.New(cfg)
-		lbTask := r.task.Parent().Subtask("loadbalancer " + cfg.Link)
-		lbTask.OnCancel("remove lb from routes", func() {
-			routes.DeleteHTTPRoute(cfg.Link)
-		})
-		if err := lb.Start(lbTask); err != nil {
+		if err := lb.Start(parent); err != nil {
 			panic(err) // should always return nil
 		}
 		linked = &HTTPRoute{

@@ -203,9 +200,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
 		routes.SetHTTPRoute(cfg.Link, linked)
 	}
 	r.loadBalancer = lb
-	r.server = loadbalance.NewServer(r.task.String(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
+	r.server = loadbalance.NewServer(r.task.Name(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
 	lb.AddServer(r.server)
-	r.task.OnCancel("remove server from lb", func() {
+	r.task.OnCancel("lb_remove_server", func() {
 		lb.RemoveServer(r.server)
 	})
 }
@@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler {
 	}
 }
 
-func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) {
+func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) {
 	oldRoutes := handler.provider.routes
 	newRoutes, err := handler.provider.loadRoutesImpl()
 	if err != nil {

@@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
 	return false
 }
 
-func (handler *EventHandler) Add(parent *task.Task, route *route.Route) {
+func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
 	err := handler.provider.startRoute(parent, route)
 	if err != nil {
 		handler.errs.Add(err.Subject("add"))

@@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) {
 	handler.removed.Adds(route.Entry.Alias)
 }
 
-func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) {
+func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) {
 	oldRoute.Finish("route update")
 	err := handler.provider.startRoute(parent, newRoute)
 	if err != nil {
@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"path"
+	"strings"
 	"time"
 
 	"github.com/rs/zerolog"

@@ -60,7 +61,7 @@ func NewFileProvider(filename string) (p *Provider, err error) {
 	if name == "" {
 		return nil, ErrEmptyProviderName
 	}
-	p = newProvider(name, ProviderTypeFile)
+	p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile)
 	p.ProviderImpl, err = FileProviderImpl(filename)
 	if err != nil {
 		return nil, err

@@ -100,46 +101,43 @@ func (p *Provider) MarshalText() ([]byte, error) {
 	return []byte(p.String()), nil
 }
 
-func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error {
-	subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
-	err := r.Start(subtask)
+func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error {
+	err := r.Start(parent)
 	if err != nil {
 		p.routes.Delete(r.Entry.Alias)
-		subtask.Finish(err) // just to ensure
 		return err.Subject(r.Entry.Alias)
 	}
 
 	p.routes.Store(r.Entry.Alias, r)
-	subtask.OnFinished("del from provider", func() {
+	r.Task().OnFinished("provider_remove_route", func() {
 		p.routes.Delete(r.Entry.Alias)
 	})
 	return nil
 }
 
-// Start implements*task.TaskStarter.
-func (p *Provider) Start(configSubtask *task.Task) E.Error {
-	// routes and event queue will stop on parent cancel
-	providerTask := configSubtask
+func (p *Provider) Start(parent task.Parent) E.Error {
+	t := parent.Subtask("provider."+p.name, false)
+
+	// routes and event queue will stop on config reload
 	errs := p.routes.CollectErrorsParallel(
 		func(alias string, r *R.Route) error {
-			return p.startRoute(providerTask, r)
+			return p.startRoute(t, r)
 		})
 
 	eventQueue := events.NewEventQueue(
-		providerTask,
+		t.Subtask("event_queue", false),
 		providerEventFlushInterval,
-		func(flushTask *task.Task, events []events.Event) {
+		func(events []events.Event) {
 			handler := p.newEventHandler()
-			// routes' lifetime should follow the provider's lifetime
-			handler.Handle(providerTask, events)
+			handler.Handle(t, events)
 			handler.Log()
-			flushTask.Finish("events flushed")
 		},
 		func(err E.Error) {
 			E.LogError("event error", err, p.Logger())
 		},
 	)
-	eventQueue.Start(p.watcher.Events(providerTask.Context()))
+	eventQueue.Start(p.watcher.Events(t.Context()))
 
 	if err := E.Join(errs...); err != nil {
 		return err.Subject(p.String())
@@ -47,20 +47,22 @@ func (r *StreamRoute) String() string {
 	return "stream " + r.TargetName()
 }
 
-// Start implements*task.TaskStarter.
-func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
+// Start implements task.TaskStarter.
+func (r *StreamRoute) Start(parent task.Parent) E.Error {
 	if entry.ShouldNotServe(r) {
-		providerSubtask.Finish("should not serve")
 		return nil
 	}
 
-	r.task = providerSubtask
+	r.task = parent.Subtask("stream." + r.TargetName())
 	r.Stream = NewStream(r)
 
+	parent.OnCancel("finish", func() {
+		r.task.Finish(nil)
+	})
+
 	switch {
 	case entry.UseIdleWatcher(r):
-		wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
-		waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
+		waker, err := idlewatcher.NewStreamWaker(r.task, r.StreamEntry, r.Stream)
 		if err != nil {
 			r.task.Finish(err)
 			return err

@@ -73,7 +75,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		if err == nil {
 			fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck)
 			r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
-			r.task.OnCancel("close docker client", client.Close)
+			r.task.OnCancel("close_docker_client", client.Close)
 		}
 	}
 	if r.HealthMon == nil {

@@ -86,7 +88,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		return E.From(err)
 	}
 
-	r.task.OnFinished("close stream", func() {
+	r.task.OnFinished("close_stream", func() {
 		if err := r.Stream.Close(); err != nil {
 			E.LogError("close stream failed", err, &r.l)
 		}

@@ -97,22 +99,26 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		Msg("listening")
 
 	if r.HealthMon != nil {
-		healthMonTask := r.task.Subtask("health monitor")
-		if err := r.HealthMon.Start(healthMonTask); err != nil {
+		if err := r.HealthMon.Start(r.task); err != nil {
 			E.LogWarn("health monitor error", err, &r.l)
-			healthMonTask.Finish(err)
 		}
 	}
 
 	go r.acceptConnections()
 
 	routes.SetStreamRoute(r.TargetName(), r)
-	r.task.OnFinished("remove from route table", func() {
+	r.task.OnFinished("entrypoint_remove_route", func() {
 		routes.DeleteStreamRoute(r.TargetName())
 	})
 	return nil
 }
 
+// Task implements task.TaskStarter.
+func (r *StreamRoute) Task() *task.Task {
+	return r.task
+}
+
 // Finish implements task.TaskFinisher.
 func (r *StreamRoute) Finish(reason any) {
 	r.task.Finish(reason)
 }
@@ -95,7 +95,7 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
 			return fmt.Errorf("unexpected listener type: %T", stream)
 		}
 	case io.ReadWriteCloser:
-		stream.task.OnCancel("close conn", func() { conn.Close() })
+		stream.task.OnCancel("close_conn", func() { conn.Close() })
 
 		dialer := &net.Dialer{Timeout: streamDialTimeout}
 		dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
internal/task/task.go

@@ -4,355 +4,194 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"strings"
+	"runtime/debug"
 	"sync"
-	"time"
 
 	"github.com/yusing/go-proxy/internal/common"
 	E "github.com/yusing/go-proxy/internal/error"
-	"github.com/yusing/go-proxy/internal/logging"
-	F "github.com/yusing/go-proxy/internal/utils/functional"
+	"github.com/yusing/go-proxy/internal/utils/strutils"
 )
 
-var globalTask = createGlobalTask()
-
-func createGlobalTask() (t *Task) {
-	t = new(Task)
-	t.name = "root"
-	t.ctx, t.cancel = context.WithCancelCause(context.Background())
-	t.subtasks = F.NewSet[*Task]()
-	return
-}
-
-func testResetGlobalTask() {
-	globalTask = createGlobalTask()
-}
-
 type (
 	TaskStarter interface {
 		// Start starts the object that implements TaskStarter,
 		// and returns an error if it fails to start.
 		//
-		// The task passed must be a subtask of the caller task.
-		//
-		// callerSubtask.Finish must be called when start fails or the object is finished.
-		Start(callerSubtask *Task) E.Error
+		Start(parent Parent) E.Error
+		Task() *Task
 	}
 	TaskFinisher interface {
-		// Finish marks the task as finished and cancel its context.
-		//
-		// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
-		// of the task to finish.
-		//
-		// Note that it will also cancel all subtasks.
 		Finish(reason any)
 	}
 	// Task controls objects' lifetime.
 	//
 	// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
 	//
-	// When passing a Task object to another function,
-	// it must be a sub-Task of the current Task,
-	// in name of "`currentTaskName`Subtask"
-	//
 	// Use Task.Finish to stop all subtasks of the Task.
 	Task struct {
+		name string
+
+		children sync.WaitGroup
+
+		onFinished sync.WaitGroup
+		finished   chan struct{}
+
 		ctx    context.Context
 		cancel context.CancelCauseFunc
 
-		parent     *Task
-		subtasks   F.Set[*Task]
-		subTasksWg sync.WaitGroup
-
-		name string
-
-		OnFinishedFuncs []func()
-		OnFinishedMu    sync.Mutex
-		onFinishedWg    sync.WaitGroup
-
-		finishOnce sync.Once
+		once sync.Once
 	}
+	Parent interface {
+		Context() context.Context
+		Subtask(name string, needFinish ...bool) *Task
+		Name() string
+		Finish(reason any)
+		OnCancel(name string, f func())
+	}
 )
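Under the new contract an object implements Start(parent task.Parent), creates its own subtask from the parent, and exposes it through Task(); the caller no longer pre-creates a subtask and hands it in. A minimal sketch against the interfaces above; the Service type and its names are illustrative, not from this commit:

package main

import (
	E "github.com/yusing/go-proxy/internal/error"
	"github.com/yusing/go-proxy/internal/task"
)

// Service follows the new pattern: it owns its subtask instead of
// receiving a pre-made one from the caller.
type Service struct {
	t *task.Task
}

// Start implements task.TaskStarter.
func (s *Service) Start(parent task.Parent) E.Error {
	s.t = parent.Subtask("service", true) // true: we call Finish ourselves
	s.t.OnFinished("cleanup", func() { /* release resources here */ })
	return nil
}

// Task implements task.TaskStarter.
func (s *Service) Task() *task.Task { return s.t }

// Finish implements task.TaskFinisher.
func (s *Service) Finish(reason any) { s.t.Finish(reason) }

func main() {
	var s Service
	_ = s.Start(task.RootTask("demo", false))
	s.Finish(nil)
}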
-var (
-	ErrProgramExiting = errors.New("program exiting")
-	ErrTaskCanceled   = errors.New("task canceled")
-
-	logger = logging.With().Str("module", "task").Logger()
-)
-
-// GlobalTask returns a new Task with the given name, derived from the global context.
-func GlobalTask(format string, args ...any) *Task {
-	if len(args) > 0 {
-		format = fmt.Sprintf(format, args...)
-	}
-	return globalTask.Subtask(format)
-}
-
-// DebugTaskMap returns a map[string]any representation of the global task tree.
-//
-// The returned map is suitable for encoding to JSON, and can be used
-// to debug the task tree.
-//
-// The returned map is not guaranteed to be stable, and may change
-// between runs of the program. It is intended for debugging purposes
-// only.
-func DebugTaskMap() map[string]any {
-	return globalTask.serialize()
-}
-
-// CancelGlobalContext cancels the global task context, which will cause all tasks
-// created to be canceled. This should be called before exiting the program
-// to ensure that all tasks are properly cleaned up.
-func CancelGlobalContext() {
-	globalTask.cancel(ErrProgramExiting)
-}
-
-// GlobalContextWait waits for all tasks to finish, up to the given timeout.
-//
-// If the timeout is exceeded, it prints a list of all tasks that were
-// still running when the timeout was reached, and their current tree
-// of subtasks.
-func GlobalContextWait(timeout time.Duration) (err error) {
-	done := make(chan struct{})
-	after := time.After(timeout)
-	go func() {
-		globalTask.Wait()
-		close(done)
-	}()
-	for {
-		select {
-		case <-done:
-			return
-		case <-after:
-			logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
-			return context.DeadlineExceeded
-		}
-	}
-}
-
-func (t *Task) trace(msg string) {
-	logger.Trace().Str("name", t.name).Msg(msg)
-}
-
-// Name returns the name of the task.
-func (t *Task) Name() string {
-	if !common.IsTrace {
-		return t.name
-	}
-	parts := strings.Split(t.name, " > ")
-	return parts[len(parts)-1]
-}
-
-// String returns the name of the task.
-func (t *Task) String() string {
-	return t.name
-}
-
 // Context returns the context associated with the task. This context is
 // canceled when Finish of the task is called, or parent task is canceled.
 func (t *Task) Context() context.Context {
 	return t.ctx
 }
 
+func (t *Task) Finished() <-chan struct{} {
+	return t.finished
+}
+
 // FinishCause returns the reason / error that caused the task to be finished.
 func (t *Task) FinishCause() error {
-	cause := context.Cause(t.ctx)
-	if cause == nil {
-		return t.ctx.Err()
-	}
-	return cause
+	return context.Cause(t.ctx)
 }
 
-// Parent returns the parent task of the current task.
-func (t *Task) Parent() *Task {
-	return t.parent
-}
-
-func (t *Task) runAllOnFinished(onCompTask *Task) {
-	<-t.ctx.Done()
-	t.WaitSubTasks()
-	for _, OnFinishedFunc := range t.OnFinishedFuncs {
-		OnFinishedFunc()
-		t.onFinishedWg.Done()
-	}
-	onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
-}
-
-// OnFinished calls fn when all subtasks are finished.
+// OnFinished calls fn when the task is canceled and all subtasks are finished.
 //
-// It cannot be called after Finish or Wait is called.
+// It should not be called after Finish is called.
 func (t *Task) OnFinished(about string, fn func()) {
-	if t.parent == globalTask {
-		t.OnCancel(about, fn)
-		return
-	}
-	t.onFinishedWg.Add(1)
-	t.OnFinishedMu.Lock()
-	defer t.OnFinishedMu.Unlock()
-
-	if t.OnFinishedFuncs == nil {
-		onCompTask := GlobalTask(t.name + " > OnFinished > " + about)
-		go t.runAllOnFinished(onCompTask)
-	}
-	idx := len(t.OnFinishedFuncs)
-	wrapped := func() {
-		defer func() {
-			if err := recover(); err != nil {
-				logger.Error().
-					Str("name", t.name).
-					Interface("err", err).
-					Msg("panic in " + about)
-			}
-		}()
-		fn()
-		logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about)
-	}
-	t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
+	t.onCancel(about, fn, true)
 }
 
 // OnCancel calls fn when the task is canceled.
 //
-// It cannot be called after Finish or Wait is called.
+// It should not be called after Finish is called.
 func (t *Task) OnCancel(about string, fn func()) {
-	onCompTask := GlobalTask(t.name + " > OnFinished")
+	t.onCancel(about, fn, false)
+}
+
+func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
+	t.onFinished.Add(1)
 	go func() {
 		<-t.ctx.Done()
-		fn()
-		onCompTask.Finish("done")
-		t.trace("onCancel done: " + about)
+		if waitSubTasks {
+			t.children.Wait()
+		}
+		t.invokeWithRecover(fn, about)
+		t.onFinished.Done()
 	}()
 }
 
-// Finish marks the task as finished and cancel its context.
-//
-// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
-// of the task to finish.
-//
-// Note that it will also cancel all subtasks.
+// Finish cancel all subtasks and wait for them to finish,
+// then marks the task as finished, with the given reason (if any).
 func (t *Task) Finish(reason any) {
-	var format string
-	switch reason.(type) {
-	case error:
-		format = "%w"
-	case string, fmt.Stringer:
-		format = "%s"
-	default:
-		format = "%v"
+	select {
+	case <-t.finished:
+		return
+	default:
 	}
-	t.finishOnce.Do(func() {
-		t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
+	t.once.Do(func() {
+		t.finish(reason)
 	})
-	t.Wait()
 }
 
+func (t *Task) finish(reason any) {
+	t.cancel(fmtCause(reason))
+	t.children.Wait()
+	t.onFinished.Wait()
+	if t.finished != nil {
+		close(t.finished)
+	}
+	logger.Trace().Msg("task " + t.name + " finished")
+}
+
+func fmtCause(cause any) error {
+	switch cause := cause.(type) {
+	case nil:
+		return nil
+	case error:
+		return cause
+	case string:
+		return errors.New(cause)
+	default:
+		return fmt.Errorf("%v", cause)
+	}
+}
 // Subtask returns a new subtask with the given name, derived from the parent's context.
 //
 // If the parent's context is already canceled, the returned subtask will be canceled immediately.
 //
-// This should not be called after Finish, Wait, or WaitSubTasks is called.
-func (t *Task) Subtask(name string) *Task {
-	ctx, cancel := context.WithCancelCause(t.ctx)
-	return t.newSubTask(ctx, cancel, name)
-}
-
-func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task {
-	parent := t
-	if common.IsTrace {
-		name = parent.name + " > " + name
-	}
-	subtask := &Task{
-		ctx:      ctx,
-		cancel:   cancel,
-		name:     name,
-		parent:   parent,
-		subtasks: F.NewSet[*Task](),
-	}
-	parent.subTasksWg.Add(1)
-	parent.subtasks.Add(subtask)
-	if common.IsTrace {
-		subtask.trace("started")
-		go func() {
-			subtask.Wait()
-			subtask.trace("finished: " + subtask.FinishCause().Error())
-		}()
-	}
-	go func() {
-		subtask.Wait()
-		parent.subtasks.Remove(subtask)
-		parent.subTasksWg.Done()
-	}()
-	return subtask
+// This should not be called after Finish is called.
+func (t *Task) Subtask(name string, needFinish ...bool) *Task {
+	nf := len(needFinish) == 0 || needFinish[0]
+
+	ctx, cancel := context.WithCancelCause(t.ctx)
+	child := &Task{
+		finished: make(chan struct{}, 1),
+		ctx:      ctx,
+		cancel:   cancel,
+	}
+	if t != root {
+		child.name = t.name + "." + name
+		allTasks.Add(child)
+	} else {
+		child.name = name
+	}
+
+	allTasksWg.Add(1)
+	t.children.Add(1)
+
+	if !nf {
+		go func() {
+			<-child.ctx.Done()
+			child.Finish(nil)
+		}()
+	}
+
+	go func() {
+		<-child.finished
+		allTasksWg.Done()
+		t.children.Done()
+		allTasks.Remove(child)
+	}()
+
+	logger.Trace().Msg("task " + child.name + " started")
+	return child
 }
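The variadic needFinish flag decides who ends a subtask: true (the default when omitted) means the creator is expected to call Finish itself, while false spawns a helper goroutine that auto-finishes the task once its context is canceled. A small usage sketch, assuming only the API shown in this diff:

package main

import "github.com/yusing/go-proxy/internal/task"

func main() {
	parent := task.RootTask("demo", false)

	// needFinish=true: we own the lifetime and must call Finish.
	worker := parent.Subtask("worker", true)
	go func() {
		defer worker.Finish("work done")
		<-worker.Context().Done() // ... do work until canceled ...
	}()

	// needFinish=false: finishes itself once the parent cancels it,
	// useful for fire-and-forget children with no cleanup of their own.
	_ = parent.Subtask("passive", false)

	parent.Finish(nil) // cancels both children and waits for them
}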
-// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
-//
-// It must be called only after Finish is called.
-func (t *Task) Wait() {
-	<-t.ctx.Done()
-	t.WaitSubTasks()
-	t.onFinishedWg.Wait()
-}
-
-// WaitSubTasks waits for all subtasks of the task to finish.
-//
-// No more subtasks can be added after this call.
-//
-// It can be called before Finish is called.
-func (t *Task) WaitSubTasks() {
-	t.subTasksWg.Wait()
-}
+// Name returns the name of the task without parent names.
+func (t *Task) Name() string {
+	parts := strutils.SplitRune(t.name, '.')
+	return parts[len(parts)-1]
+}
+
+// String returns the full name of the task.
+func (t *Task) String() string {
+	return t.name
+}
+
+// MarshalText implements encoding.TextMarshaler.
+func (t *Task) MarshalText() ([]byte, error) {
+	return []byte(t.name), nil
+}
 
-// tree returns a string representation of the task tree, with the given
-// prefix prepended to each line. The prefix is used to indent the tree,
-// and should be a string of spaces or a similar separator.
-//
-// The resulting string is suitable for printing to the console, and can be
-// used to debug the task tree.
-//
-// The tree is traversed in a depth-first manner, with each task's name and
-// line number (if available) printed on a separate line. The line number is
-// only printed if the task was created with a non-empty line argument.
-//
-// The returned string is not guaranteed to be stable, and may change between
-// runs of the program. It is intended for debugging purposes only.
-func (t *Task) tree(prefix ...string) string {
-	var sb strings.Builder
-	var pre string
-	if len(prefix) > 0 {
-		pre = prefix[0]
-		sb.WriteString(pre + "- ")
-	}
-	sb.WriteString(t.Name() + "\n")
-	t.subtasks.RangeAll(func(subtask *Task) {
-		sb.WriteString(subtask.tree(pre + "  "))
-	})
-	return sb.String()
-}
-
-// serialize returns a map[string]any representation of the task tree.
-//
-// The map contains the following keys:
-//   - name: the name of the task
-//   - subtasks: a slice of maps, each representing a subtask
-//
-// The subtask maps contain the same keys, recursively.
-//
-// The returned map is suitable for encoding to JSON, and can be used
-// to debug the task tree.
-//
-// The returned map is not guaranteed to be stable, and may change
-// between runs of the program. It is intended for debugging purposes
-// only.
-func (t *Task) serialize() map[string]any {
-	m := make(map[string]any)
-	parts := strings.Split(t.name, " > ")
-	m["name"] = parts[len(parts)-1]
-	if t.subtasks.Size() > 0 {
-		m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
-		t.subtasks.RangeAll(func(subtask *Task) {
-			m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
-		})
-	}
-	return m
-}
+func (t *Task) invokeWithRecover(fn func(), caller string) {
+	defer func() {
+		if err := recover(); err != nil {
+			logger.Error().
+				Interface("err", err).
+				Msg("panic in task " + t.name + "." + caller)
+			if common.IsDebug {
+				panic(string(debug.Stack()))
+			}
+		}
+	}()
+	fn()
+}
internal/task/task_test.go

@@ -2,132 +2,112 @@ package task
 
 import (
 	"context"
-	"sync/atomic"
+	"sync"
 	"testing"
 	"time"
 
 	. "github.com/yusing/go-proxy/internal/utils/testing"
 )
 
-const (
-	rootTaskName = "root-task"
-	subTaskName  = "subtask"
-)
-
-func TestTaskCreation(t *testing.T) {
-	rootTask := GlobalTask(rootTaskName)
-	subTask := rootTask.Subtask(subTaskName)
-
-	ExpectEqual(t, rootTaskName, rootTask.Name())
-	ExpectEqual(t, subTaskName, subTask.Name())
+func testTask() *Task {
+	return RootTask("test", false)
 }
 
-func TestTaskCancellation(t *testing.T) {
-	subTaskDone := make(chan struct{})
+func TestChildTaskCancellation(t *testing.T) {
+	t.Cleanup(testCleanup)
 
-	rootTask := GlobalTask(rootTaskName)
-	subTask := rootTask.Subtask(subTaskName)
+	parent := testTask()
+	child := parent.Subtask("")
 
 	go func() {
-		subTask.Wait()
-		close(subTaskDone)
+		defer child.Finish(nil)
+		for {
+			select {
+			case <-child.Context().Done():
+				return
+			default:
+				continue
+			}
+		}
 	}()
 
-	go rootTask.Finish(nil)
+	parent.cancel(nil) // should also cancel child
 
 	select {
-	case <-subTaskDone:
-		err := subTask.Context().Err()
-		ExpectError(t, context.Canceled, err)
-		cause := context.Cause(subTask.Context())
-		ExpectError(t, ErrTaskCanceled, cause)
-	case <-time.After(1 * time.Second):
+	case <-child.Finished():
+		ExpectError(t, context.Canceled, child.Context().Err())
+	default:
 		t.Fatal("subTask context was not canceled as expected")
 	}
 }
 
-func TestOnComplete(t *testing.T) {
-	rootTask := GlobalTask(rootTaskName)
-	task := rootTask.Subtask(subTaskName)
+func TestTaskOnCancelOnFinished(t *testing.T) {
+	t.Cleanup(testCleanup)
+	task := testTask()
 
-	var value atomic.Int32
-	task.OnFinished("set value", func() {
-		value.Store(1234)
+	var shouldTrueOnCancel bool
+	var shouldTrueOnFinish bool
+
+	task.OnCancel("", func() {
+		shouldTrueOnCancel = true
+	})
+	task.OnFinished("", func() {
+		shouldTrueOnFinish = true
 	})
 
+	ExpectFalse(t, shouldTrueOnFinish)
 	task.Finish(nil)
-	ExpectEqual(t, value.Load(), 1234)
+	ExpectTrue(t, shouldTrueOnCancel)
+	ExpectTrue(t, shouldTrueOnFinish)
 }
 
-func TestGlobalContextWait(t *testing.T) {
-	testResetGlobalTask()
-	defer CancelGlobalContext()
-
-	rootTask := GlobalTask(rootTaskName)
-
-	finished1, finished2 := false, false
-
-	subTask1 := rootTask.Subtask(subTaskName)
-	subTask2 := rootTask.Subtask(subTaskName)
-	subTask1.OnFinished("", func() {
-		finished1 = true
-	})
-	subTask2.OnFinished("", func() {
-		finished2 = true
-	})
-
-	go func() {
-		time.Sleep(500 * time.Millisecond)
-		subTask1.Finish(nil)
-	}()
-
-	go func() {
-		time.Sleep(500 * time.Millisecond)
-		subTask2.Finish(nil)
-	}()
-
-	go func() {
-		subTask1.Wait()
-		subTask2.Wait()
-		rootTask.Finish(nil)
-	}()
-
-	_ = GlobalContextWait(1 * time.Second)
-	ExpectTrue(t, finished1)
-	ExpectTrue(t, finished2)
-	ExpectError(t, context.Canceled, rootTask.Context().Err())
-	ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
-	ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
-}
-
-func TestTimeoutOnGlobalContextWait(t *testing.T) {
-	testResetGlobalTask()
-
-	rootTask := GlobalTask(rootTaskName)
-	rootTask.Subtask(subTaskName)
-
-	ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
-}
-
-func TestGlobalContextCancellation(t *testing.T) {
-	testResetGlobalTask()
-
-	taskDone := make(chan struct{})
-	rootTask := GlobalTask(rootTaskName)
-
-	go func() {
-		rootTask.Wait()
-		close(taskDone)
-	}()
-
-	CancelGlobalContext()
-
-	select {
-	case <-taskDone:
-		err := rootTask.Context().Err()
-		ExpectError(t, context.Canceled, err)
-		cause := context.Cause(rootTask.Context())
-		ExpectError(t, ErrProgramExiting, cause)
-	case <-time.After(1 * time.Second):
-		t.Fatal("subTask context was not canceled as expected")
-	}
-}
+func TestCommonFlowWithGracefulShutdown(t *testing.T) {
+	t.Cleanup(testCleanup)
+	task := testTask()
+
+	finished := false
+	task.OnFinished("", func() {
+		finished = true
+	})
+
+	go func() {
+		defer task.Finish(nil)
+		for {
+			select {
+			case <-task.Context().Done():
+				return
+			default:
+				continue
+			}
+		}
+	}()
+
+	ExpectNoError(t, GracefulShutdown(1*time.Second))
+	ExpectTrue(t, finished)
+
+	<-root.finished
+	ExpectError(t, context.Canceled, task.Context().Err())
+	ExpectError(t, ErrProgramExiting, task.FinishCause())
+}
+
+func TestTimeoutOnGracefulShutdown(t *testing.T) {
+	t.Cleanup(testCleanup)
+	_ = testTask()
+
+	ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
+}
+
+func TestFinishMultipleCalls(t *testing.T) {
+	t.Cleanup(testCleanup)
+	task := testTask()
+	var wg sync.WaitGroup
+	wg.Add(5)
+	for range 5 {
+		go func() {
+			defer wg.Done()
+			task.Finish(nil)
+		}()
+	}
+	wg.Wait()
+}
internal/task/utils.go (new file, 96 lines)
@@ -0,0 +1,96 @@
+package task
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"slices"
+	"sync"
+	"time"
+
+	"github.com/yusing/go-proxy/internal/logging"
+	F "github.com/yusing/go-proxy/internal/utils/functional"
+)
+
+var ErrProgramExiting = errors.New("program exiting")
+
+var logger = logging.With().Str("module", "task").Logger()
+
+var root = newRoot()
+var allTasks = F.NewSet[*Task]()
+var allTasksWg sync.WaitGroup
+
+func testCleanup() {
+	root = newRoot()
+	allTasks.Clear()
+	allTasksWg = sync.WaitGroup{}
+}
+
+// RootTask returns a new Task with the given name, derived from the root context.
+func RootTask(name string, needFinish bool) *Task {
+	return root.Subtask(name, needFinish)
+}
+
+func newRoot() *Task {
+	t := &Task{name: "root"}
+	t.ctx, t.cancel = context.WithCancelCause(context.Background())
+	return t
+}
+
+func RootContext() context.Context {
+	return root.ctx
+}
+
+func RootContextCanceled() <-chan struct{} {
+	return root.ctx.Done()
+}
+
+func OnProgramExit(about string, fn func()) {
+	root.OnFinished(about, fn)
+}
+
+// GracefulShutdown waits for all tasks to finish, up to the given timeout.
+//
+// If the timeout is exceeded, it prints a list of all tasks that were
+// still running when the timeout was reached, and their current tree
+// of subtasks.
+func GracefulShutdown(timeout time.Duration) (err error) {
+	root.cancel(ErrProgramExiting)
+
+	done := make(chan struct{})
+	after := time.After(timeout)
+
+	go func() {
+		allTasksWg.Wait()
+		close(done)
+	}()
+
+	for {
+		select {
+		case <-done:
+			return
+		case <-after:
+			b, err := json.Marshal(DebugTaskList())
+			if err != nil {
+				logger.Warn().Err(err).Msg("failed to marshal tasks")
+				return context.DeadlineExceeded
+			}
+			logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
+			return context.DeadlineExceeded
+		}
+	}
+}
+
+// DebugTaskList returns list of all tasks.
+//
+// The returned string is suitable for printing to the console.
+func DebugTaskList() []string {
+	l := make([]string, 0, allTasks.Size())
+
+	allTasks.RangeAll(func(t *Task) {
+		l = append(l, t.name)
+	})
+
+	slices.Sort(l)
+	return l
+}
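Taken together with the main.go change above, shutdown now runs through the root task: cleanup hooks are registered with OnProgramExit and drained by GracefulShutdown. A plausible wiring, with the signal handling being illustrative rather than the project's actual main:

package main

import (
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/yusing/go-proxy/internal/task"
)

func main() {
	task.OnProgramExit("flush_stdout", func() {
		os.Stdout.Sync() // runs once the root task is canceled
	})

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig

	// Cancels the root context and waits up to 3s for all tasks;
	// returns context.DeadlineExceeded if some are still running.
	_ = task.GracefulShutdown(3 * time.Second)
}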
@@ -1,30 +0,0 @@
-package utils
-
-import (
-	"encoding/json"
-	"sync/atomic"
-)
-
-type AtomicValue[T any] struct {
-	atomic.Value
-}
-
-func (a *AtomicValue[T]) Load() T {
-	return a.Value.Load().(T)
-}
-
-func (a *AtomicValue[T]) Store(v T) {
-	a.Value.Store(v)
-}
-
-func (a *AtomicValue[T]) Swap(v T) T {
-	return a.Value.Swap(v).(T)
-}
-
-func (a *AtomicValue[T]) CompareAndSwap(oldV, newV T) bool {
-	return a.Value.CompareAndSwap(oldV, newV)
-}
-
-func (a *AtomicValue[T]) MarshalJSON() ([]byte, error) {
-	return json.Marshal(a.Load())
-}
internal/utils/atomic/atomic_value.go (new file, 30 lines)
@@ -0,0 +1,30 @@
+package atomic
+
+import (
+	"encoding/json"
+	"sync/atomic"
+)
+
+type Value[T any] struct {
+	atomic.Value
+}
+
+func (a *Value[T]) Load() T {
+	return a.Value.Load().(T)
+}
+
+func (a *Value[T]) Store(v T) {
+	a.Value.Store(v)
+}
+
+func (a *Value[T]) Swap(v T) T {
+	return a.Value.Swap(v).(T)
+}
+
+func (a *Value[T]) CompareAndSwap(oldV, newV T) bool {
+	return a.Value.CompareAndSwap(oldV, newV)
+}
+
+func (a *Value[T]) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.Load())
+}
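The generic wrapper moves from utils.AtomicValue[T] to its own package as atomic.Value[T]; behavior is unchanged. A quick usage sketch (as with sync/atomic.Value, every Store must use the same concrete type, and Load panics if nothing was stored yet):

package main

import (
	"fmt"

	"github.com/yusing/go-proxy/internal/utils/atomic"
)

func main() {
	var v atomic.Value[int]
	v.Store(1)
	old := v.Swap(2)
	fmt.Println(old, v.Load()) // 1 2

	// MarshalJSON serializes the current value, so the wrapper can sit
	// directly inside JSON-encoded structs.
	b, _ := v.MarshalJSON()
	fmt.Println(string(b)) // 2
}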
@@ -152,9 +152,10 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
 		return m.CollectErrors(do)
 	}
 
-	errs := make([]error, 0)
-	mu := sync.Mutex{}
-	wg := sync.WaitGroup{}
+	var errs []error
+	var mu sync.Mutex
+	var wg sync.WaitGroup
+
 	m.Range(func(k KT, v VT) bool {
 		wg.Add(1)
 		go func() {

@@ -171,24 +172,6 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
 	return errs
 }
 
-// RemoveAll removes all key-value pairs from the map where the value matches the given criteria.
-//
-// Parameters:
-//
-//	criteria: function to determine whether a value should be removed
-//
-// Returns:
-//
-//	nothing
-func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
-	m.Range(func(k KT, v VT) bool {
-		if criteria(v) {
-			m.Delete(k)
-		}
-		return true
-	})
-}
-
 func (m Map[KT, VT]) Has(k KT) bool {
 	_, ok := m.Load(k)
 	return ok
@@ -61,3 +61,7 @@ func (set Set[T]) RangeAllParallel(f func(T)) {
 func (set Set[T]) Size() int {
 	return set.m.Size()
 }
+
+func (set Set[T]) IsEmpty() bool {
+	return set.m == nil || set.m.Size() == 0
+}
@ -20,10 +20,17 @@ func IgnoreError[Result any](r Result, _ error) Result {
	return r
}

func fmtError(err error) string {
	if err == nil {
		return "<nil>"
	}
	return ansi.StripANSI(err.Error())
}

func ExpectNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil && !reflect.ValueOf(err).IsNil() {
		t.Errorf("expected err=nil, got %s", ansi.StripANSI(err.Error()))
	if err != nil {
		t.Errorf("expected err=nil, got %s", fmtError(err))
		t.FailNow()
	}
}

@ -31,7 +38,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) {
	t.Helper()
	if !errors.Is(err, expected) {
		t.Errorf("expected err %s, got %s", expected, ansi.StripANSI(err.Error()))
		t.Errorf("expected err %s, got %s", expected, fmtError(err))
		t.FailNow()
	}
}

@ -39,7 +46,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) {
	t.Helper()
	if !errors.Is(err, expected) {
		t.Errorf("%v: expected err %s, got %s", input, expected, ansi.StripANSI(err.Error()))
		t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err))
		t.FailNow()
	}
}

@ -48,7 +55,7 @@ func ExpectErrorT[T error](t *testing.T, err error) {
	t.Helper()
	var errAs T
	if !errors.As(err, &errAs) {
		t.Errorf("expected err %T, got %s", errAs, ansi.StripANSI(err.Error()))
		t.Errorf("expected err %T, got %s", errAs, fmtError(err))
		t.FailNow()
	}
}
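Routing every message through fmtError gives the helpers a single place to handle nil errors and strip ANSI color codes. A self-contained approximation; the regexp stands in for the internal ansi.StripANSI, whose implementation this diff does not show:

package testutil

import (
	"regexp"
	"testing"
)

// ansiRe matches SGR color escape sequences; it stands in for the
// internal ansi.StripANSI helper, which is not part of this diff.
var ansiRe = regexp.MustCompile(`\x1b\[[0-9;]*m`)

func fmtError(err error) string {
	if err == nil {
		return "<nil>"
	}
	return ansiRe.ReplaceAllString(err.Error(), "")
}

// ExpectNoError fails the test immediately when err is non-nil.
func ExpectNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Errorf("expected err=nil, got %s", fmtError(err))
		t.FailNow()
	}
}

One behavioral change worth noting: the old err != nil && !reflect.ValueOf(err).IsNil() guard also tolerated a typed-nil error stored in a non-nil interface, while the simplified check treats that as a failure.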
@ -16,8 +16,10 @@ var (
func NewConfigFileWatcher(filename string) Watcher {
	configDirWatcherMu.Lock()
	defer configDirWatcherMu.Unlock()

	if configDirWatcher == nil {
		configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath)
		t := task.RootTask("config_dir_watcher", false)
		configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath)
	}
	return configDirWatcher.Add(filename)
}
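The mutex covers both the nil check and the construction, so concurrent callers can never create two directory watchers. The same shape, reduced to a standalone sketch with hypothetical names:

package main

import (
	"fmt"
	"sync"
)

type dirWatcher struct{ dir string }

var (
	watcherMu sync.Mutex
	watcher   *dirWatcher
)

// getWatcher lazily builds the shared watcher; the lock spans the
// check and the assignment, so initialization happens exactly once.
func getWatcher(dir string) *dirWatcher {
	watcherMu.Lock()
	defer watcherMu.Unlock()
	if watcher == nil {
		watcher = &dirWatcher{dir: dir}
	}
	return watcher
}

func main() {
	a := getWatcher("/config")
	b := getWatcher("/config")
	fmt.Println(a == b) // true
}

sync.OnceValue would express the same intent, at the cost of fixing the directory at the first call, which is effectively what happens here as well.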
@ -37,7 +37,7 @@ type DirWatcher struct {
//
// Note that the returned DirWatcher is not ready to use until the goroutine
// started by NewDirectoryWatcher has finished.
func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
	//! subdirectories are not watched
	w, err := fsnotify.NewWatcher()
	if err != nil {

@ -56,7 +56,7 @@ func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
		fwMap:   F.NewMapOf[string, *fileWatcher](),
		eventCh: make(chan Event),
		errCh:   make(chan E.Error),
		task:    callerSubtask,
		task:    parent.Subtask("dir_watcher(" + dirPath + ")"),
	}
	go helper.start()
	return helper

@ -80,17 +80,19 @@ func (h *DirWatcher) Add(relPath string) Watcher {
		eventCh: make(chan Event),
		errCh:   make(chan E.Error),
	}
	h.task.OnFinished("close file watcher for "+relPath, func() {
		close(s.eventCh)
		close(s.errCh)
	})
	h.fwMap.Store(relPath, s)
	return s
}

func (h *DirWatcher) cleanup() {
	h.w.Close()
	close(h.eventCh)
	close(h.errCh)
	h.task.Finish(nil)
}

func (h *DirWatcher) start() {
	defer close(h.eventCh)
	defer h.w.Close()
	defer h.cleanup()

	for {
		select {
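Collapsing the separate defer close(...) and defer h.w.Close() calls into one cleanup method keeps teardown in a single place and lets it also mark the task finished. A reduced sketch of the pattern, with hypothetical types:

package main

import "fmt"

type watcher struct {
	eventCh chan string
	errCh   chan error
}

// cleanup is the single teardown point: every exit path of run
// releases resources here instead of via scattered defers.
func (w *watcher) cleanup() {
	close(w.eventCh)
	close(w.errCh)
	fmt.Println("task finished") // stands in for h.task.Finish(nil)
}

func (w *watcher) run(stop <-chan struct{}) {
	defer w.cleanup()
	<-stop // a real loop would also select over filesystem events
}

func main() {
	w := &watcher{eventCh: make(chan string), errCh: make(chan error)}
	stop := make(chan struct{})
	close(stop)
	w.run(stop) // prints "task finished" via the deferred cleanup
}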
@ -1,6 +1,7 @@
package events

import (
	"runtime/debug"
	"time"

	"github.com/yusing/go-proxy/internal/common"

@ -17,7 +18,7 @@ type (
		onFlush OnFlushFunc
		onError OnErrorFunc
	}
	OnFlushFunc = func(flushTask *task.Task, events []Event)
	OnFlushFunc = func(events []Event)
	OnErrorFunc = func(err E.Error)
)

@ -38,9 +39,9 @@ const eventQueueCapacity = 10
// but the onFlush function can return earlier (e.g. run in another goroutine).
//
// If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
func NewEventQueue(queueTask *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
	return &EventQueue{
		task:          parent.Subtask("event queue"),
		task:          queueTask,
		queue:         make([]Event, 0, eventQueueCapacity),
		ticker:        time.NewTicker(flushInterval),
		flushInterval: flushInterval,

@ -50,19 +51,20 @@ func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlu
}

func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
	if common.IsProduction {
		origOnFlush := e.onFlush
		// recover panic in onFlush when in production mode
		e.onFlush = func(flushTask *task.Task, events []Event) {
		e.onFlush = func(events []Event) {
			defer func() {
				if err := recover(); err != nil {
					e.onError(E.New("recovered panic in onFlush").
						Withf("%v", err).
						Subject(e.task.Parent().String()))
						Subject(e.task.Name()))
					if common.IsDebug {
						panic(string(debug.Stack()))
					}
				}
			}()
			origOnFlush(flushTask, events)
		}
			origOnFlush(events)
		}

	go func() {

@ -75,19 +77,24 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
				return
			case <-e.ticker.C:
				if len(e.queue) > 0 {
					flushTask := e.task.Subtask("flush events")
					queue := e.queue
					e.queue = make([]Event, 0, eventQueueCapacity)
					go e.onFlush(flushTask, queue)
					flushTask.Wait()
					// clone -> clear -> flush
					queue := make([]Event, len(e.queue))
					copy(queue, e.queue)

					e.queue = e.queue[:0]

					e.onFlush(queue)
				}
				e.ticker.Reset(e.flushInterval)
			case event, ok := <-eventCh:
				e.queue = append(e.queue, event)
				if !ok {
					return
				}
			case err := <-errCh:
				e.queue = append(e.queue, event)
			case err, ok := <-errCh:
				if !ok {
					return
				}
				if err != nil {
					e.onError(err)
				}

@ -95,10 +102,3 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
		}
	}()
}

// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
	e.task.Wait()
}
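Replacing the blocking flush subtask with a clone, clear, flush sequence is the heart of this simplification: the batch is copied before the queue is truncated, because e.queue[:0] reuses the backing array, and later appends would otherwise overwrite a batch the handler still holds. A standalone sketch of the same loop over plain strings:

package main

import (
	"fmt"
	"time"
)

// flushLoop drains events on a ticker using clone -> clear -> flush.
func flushLoop(events <-chan string, interval time.Duration, onFlush func([]string)) {
	queue := make([]string, 0, 10)
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case ev, ok := <-events:
			if !ok {
				return
			}
			queue = append(queue, ev)
		case <-ticker.C:
			if len(queue) == 0 {
				continue
			}
			batch := make([]string, len(queue)) // clone: queue[:0] below reuses the backing array
			copy(batch, queue)
			queue = queue[:0] // clear, keeping capacity
			onFlush(batch)    // flush synchronously; no flush subtask needed
		}
	}
}

func main() {
	events := make(chan string, 3)
	events <- "a"
	events <- "b"
	go flushLoop(events, 10*time.Millisecond, func(b []string) { fmt.Println(b) })
	time.Sleep(50 * time.Millisecond) // prints [a b]
}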
@ -13,7 +13,7 @@ import (
	"github.com/yusing/go-proxy/internal/net/types"
	"github.com/yusing/go-proxy/internal/notif"
	"github.com/yusing/go-proxy/internal/task"
	U "github.com/yusing/go-proxy/internal/utils"
	"github.com/yusing/go-proxy/internal/utils/atomic"
	"github.com/yusing/go-proxy/internal/utils/strutils"
	"github.com/yusing/go-proxy/internal/watcher/health"
)

@ -23,9 +23,9 @@ type (
	monitor struct {
		service string
		config  *health.HealthCheckConfig
		url     U.AtomicValue[types.URL]
		url     atomic.Value[types.URL]

		status     U.AtomicValue[health.Status]
		status     atomic.Value[health.Status]
		lastResult *health.HealthCheckResult
		lastSeen   time.Time

@ -59,10 +59,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
}

// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
	mon.service = routeSubtask.Parent().Name()
	mon.task = routeSubtask

func (mon *monitor) Start(parent task.Parent) E.Error {
	if mon.config.Interval <= 0 {
		return E.From(ErrNegativeInterval)
	}

@ -71,6 +68,9 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
		mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service))
	}

	mon.service = parent.Name()
	mon.task = parent.Subtask("health_monitor")

	go func() {
		logger := logging.With().Str("name", mon.service).Logger()

@ -78,10 +78,10 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
			if mon.status.Load() != health.StatusError {
				mon.status.Store(health.StatusUnknown)
			}
			mon.task.Finish(nil)
			if mon.metric != nil {
				mon.metric.Reset()
			}
			mon.task.Finish(nil)
		}()

		if err := mon.checkUpdateHealth(); err != nil {

@ -108,6 +108,11 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
	return nil
}

// Task implements task.TaskStarter.
func (mon *monitor) Task() *task.Task {
	return mon.task
}

// Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason any) {
	mon.task.Finish(reason)
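The signature change from Start(routeSubtask *task.Task) to Start(parent task.Parent) moves subtask creation into the component itself: the monitor derives its own task instead of being handed one and reaching back through Parent(). A reduced sketch of that inversion; the Parent interface here is an assumption inferred from the two calls visible in this diff (Name() and Subtask(...)), not the task package's actual definition:

package main

import "fmt"

// Parent is assumed from the calls above; the real task.Parent may differ.
type Parent interface {
	Name() string
	Subtask(name string) *Task
}

type Task struct{ name string }

func (t *Task) Name() string              { return t.name }
func (t *Task) Subtask(name string) *Task { return &Task{name: t.name + "/" + name} }

type monitor struct {
	service string
	task    *Task
}

// Start derives the monitor's own subtask from its parent.
func (m *monitor) Start(parent Parent) {
	m.service = parent.Name()
	m.task = parent.Subtask("health_monitor")
}

func main() {
	root := &Task{name: "route"}
	m := &monitor{}
	m.Start(root)
	fmt.Println(m.service, m.task.Name()) // route route/health_monitor
}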