Mirror of https://github.com/yusing/godoxy.git (synced 2025-05-20 04:42:33 +02:00)

Commit 1ab34ed46f (parent e7aaa95ec5): simplify task package implementation
35 changed files with 547 additions and 600 deletions
@@ -23,16 +23,16 @@ lint:
   enabled:
     - hadolint@2.12.1-beta
     - actionlint@1.7.4
-    - checkov@3.2.334
+    - checkov@3.2.344
     - git-diff-check
     - gofmt@1.20.4
     - golangci-lint@1.62.2
-    - osv-scanner@1.9.1
+    - osv-scanner@1.9.2
    - oxipng@9.1.3
     - prettier@3.4.2
     - shellcheck@0.10.0
     - shfmt@3.6.0
-    - trufflehog@3.86.1
+    - trufflehog@3.88.0
 actions:
   disabled:
     - trunk-announce
Makefile (4 changes)

@@ -28,10 +28,10 @@ get:
 	go get -u ./cmd && go mod tidy

 debug:
-	GODOXY_DEBUG=1 make run
+	GODOXY_DEBUG=1 BUILD_FLAGS="" make run

 debug-trace:
-	GODOXY_DEBUG=1 GODOXY_TRACE=1 run
+	GODOXY_TRACE=1 make debug

 profile:
 	GODEBUG=gctrace=1 make debug
@@ -159,8 +159,7 @@ func main() {

 	// grafully shutdown
 	logging.Info().Msg("shutting down")
-	task.CancelGlobalContext()
-	_ = task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
+	_ = task.GracefulShutdown(time.Second * time.Duration(config.Value().TimeoutShutdown))
 }

 func prepareDirectory(dir string) {
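The two-step teardown (cancel the global context, then wait) is folded into a single task.GracefulShutdown call. A minimal sketch of how a caller in this repository might wire it to OS signals; the only thing taken from the diff is the GracefulShutdown(timeout) signature, the rest is an assumption for illustration:

package main

import (
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/yusing/go-proxy/internal/task"
)

// waitAndShutdown blocks until SIGINT/SIGTERM, then cancels the root task
// and waits up to `timeout` for every subtask to finish (hypothetical wiring).
func waitAndShutdown(timeout time.Duration) {
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	<-sig
	_ = task.GracefulShutdown(timeout)
}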
@@ -52,7 +52,7 @@ func List(w http.ResponseWriter, r *http.Request) {
 	case ListHomepageConfig:
 		U.RespondJSON(w, r, config.HomepageConfig())
 	case ListTasks:
-		U.RespondJSON(w, r, task.DebugTaskMap())
+		U.RespondJSON(w, r, task.DebugTaskList())
 	default:
 		U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
 	}
@@ -153,8 +153,8 @@ func (p *Provider) ScheduleRenewal() {
 		return
 	}
 	go func() {
-		task := task.GlobalTask("cert renew scheduler")
-		defer task.Finish("cert renew scheduler stopped")
+		task := task.RootTask("cert-renew-scheduler", true)
+		defer task.Finish(nil)

 		for {
 			renewalTime := p.ShouldRenewOn()
@@ -53,7 +53,7 @@ func newConfig() *Config {
 	return &Config{
 		value:     types.DefaultConfig(),
 		providers: F.NewMapOf[string, *proxy.Provider](),
-		task:      task.GlobalTask("config"),
+		task:      task.RootTask("config", false),
 	}
 }
@@ -76,21 +76,19 @@ func MatchDomains() []string {
 }

 func WatchChanges() {
-	task := task.GlobalTask("config watcher")
+	t := task.RootTask("config_watcher", true)
 	eventQueue := events.NewEventQueue(
-		task,
+		t,
 		configEventFlushInterval,
 		OnConfigChange,
 		func(err E.Error) {
 			E.LogError("config reload error", err, &logger)
 		},
 	)
-	eventQueue.Start(cfgWatcher.Events(task.Context()))
+	eventQueue.Start(cfgWatcher.Events(t.Context()))
 }

-func OnConfigChange(flushTask *task.Task, ev []events.Event) {
-	defer flushTask.Finish("config reload complete")
-
+func OnConfigChange(ev []events.Event) {
 	// no matter how many events during the interval
 	// just reload once and check the last event
 	switch ev[len(ev)-1].Action {
@@ -116,14 +114,14 @@ func Reload() E.Error {
 	newCfg := newConfig()
 	err := newCfg.load()
 	if err != nil {
+		newCfg.task.Finish(err)
 		return err
 	}

 	// cancel all current subtasks -> wait
 	// -> replace config -> start new subtasks
 	instance.task.Finish("config changed")
-	instance.task.Wait()
-	*instance = *newCfg
+	instance = newCfg
 	instance.StartProxyProviders()
 	return nil
 }
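Reload now swaps the package-level pointer to the freshly loaded config instead of copying the new struct over the old one. A standalone sketch (names are illustrative, not from the repository) of why the two differ for goroutines still holding the old pointer:

package main

import "fmt"

type config struct{ value int }

var instance = &config{value: 1}

func main() {
	old := instance // another goroutine may still hold this pointer

	// copy-over (old behaviour): readers of `old` see the new value too
	*instance = config{value: 2}
	fmt.Println(old.value) // 2

	// pointer swap (new behaviour): `old` keeps its previous snapshot
	instance = &config{value: 3}
	fmt.Println(old.value, instance.value) // 2 3
}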
@@ -143,8 +141,7 @@ func (cfg *Config) Task() *task.Task {
 func (cfg *Config) StartProxyProviders() {
 	errs := cfg.providers.CollectErrorsParallel(
 		func(_ string, p *proxy.Provider) error {
-			subtask := cfg.task.Subtask(p.String())
-			return p.Start(subtask)
+			return p.Start(cfg.task)
 		})

 	if err := E.Join(errs...); err != nil {
@@ -209,9 +206,6 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
 }

 func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
-	subtask := cfg.task.Subtask("load route providers")
-	defer subtask.Finish("done")
-
 	errs := E.NewBuilder("route provider errors")
 	results := E.NewBuilder("loaded route providers")
@@ -38,7 +38,7 @@ var (
 )

 func init() {
-	task.GlobalTask("close docker clients").OnFinished("", func() {
+	task.OnProgramExit("docker_clients_cleanup", func() {
 		clientMap.RangeAllParallel(func(_ string, c Client) {
 			if c.Connected() {
 				c.Client.Close()
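Cleanup callbacks previously hung off ad-hoc global tasks become named task.OnProgramExit hooks. A minimal hedged sketch of registering such a hook elsewhere in the codebase, assuming only the (name string, fn func()) signature visible in this hunk:

package somepkg // illustrative placement

import "github.com/yusing/go-proxy/internal/task"

func init() {
	// Assumed behaviour (from this diff): fn runs once when the program's
	// root task shuts down, e.g. to close pooled connections.
	task.OnProgramExit("somepkg_cleanup", func() {
		// release resources here
	})
}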
@@ -38,7 +38,7 @@ const (

 // TODO: support stream

-func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
+func newWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
 	hcCfg := entry.RawEntry().HealthCheck
 	hcCfg.Timeout = idleWakerCheckTimeout
@@ -46,8 +46,8 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
 		rp:     rp,
 		stream: stream,
 	}
-
-	watcher, err := registerWatcher(providerSubTask, entry, waker)
+	task := parent.Subtask("idlewatcher")
+	watcher, err := registerWatcher(task, entry, waker)
 	if err != nil {
 		return nil, E.Errorf("register watcher: %w", err)
 	}
@@ -63,7 +63,7 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP

 	if common.PrometheusEnabled {
 		m := metrics.GetServiceMetrics()
-		fqn := providerSubTask.Parent().Name() + "/" + entry.TargetName()
+		fqn := parent.Name() + "/" + entry.TargetName()
 		waker.metric = m.HealthStatus.With(metrics.HealthMetricLabels(fqn))
 		waker.metric.Set(float64(watcher.Status()))
 	}
@@ -71,19 +71,18 @@ func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseP
 }

 // lifetime should follow route provider.
-func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
-	return newWaker(providerSubTask, entry, rp, nil)
+func NewHTTPWaker(parent task.Parent, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
+	return newWaker(parent, entry, rp, nil)
 }

-func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) {
-	return newWaker(providerSubTask, entry, nil, stream)
+func NewStreamWaker(parent task.Parent, entry route.Entry, stream net.Stream) (Waker, E.Error) {
+	return newWaker(parent, entry, nil, stream)
 }

 // Start implements health.HealthMonitor.
-func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
-	routeSubTask.Finish("ignored")
-	w.task.OnCancel("stop route and cleanup", func() {
-		routeSubTask.Parent().Finish(w.task.FinishCause())
+func (w *Watcher) Start(parent task.Parent) E.Error {
+	w.task.OnCancel("route_cleanup", func() {
+		parent.Finish(w.task.FinishCause())
 		if w.metric != nil {
 			w.metric.Reset()
 		}
@@ -91,6 +90,11 @@ func (w *Watcher) Start(routeSubTask *task.Task) E.Error {
 	return nil
 }

+// Task implements health.HealthMonitor.
+func (w *Watcher) Task() *task.Task {
+	return w.task
+}
+
 // Finish implements health.HealthMonitor.
 func (w *Watcher) Finish(reason any) {
 	if w.stream != nil {
@@ -51,7 +51,7 @@ var (

 const dockerReqTimeout = 3 * time.Second

-func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
+func registerWatcher(watcherTask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) {
 	cfg := entry.IdlewatcherConfig()

 	if cfg.IdleTimeout == 0 {
@@ -67,7 +67,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
 		w.Config = cfg
 		w.waker = waker
 		w.resetIdleTimer()
-		providerSubtask.Finish("used existing watcher")
+		watcherTask.Finish("used existing watcher")
 		return w, nil
 	}
@@ -81,7 +81,7 @@ func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker
 		Config: cfg,
 		waker:  waker,
 		client: client,
-		task:   providerSubtask,
+		task:   watcherTask,
 		ticker: time.NewTicker(cfg.IdleTimeout),
 	}
 	w.stopByMethod = w.getStopCallback()
@@ -210,8 +210,7 @@ func (w *Watcher) resetIdleTimer() {
 }

 func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
-	eventTask = w.task.Subtask("docker event watcher")
-	eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{
+	eventCh, errCh = dockerWatcher.EventsWithOptions(w.Task().Context(), watcher.DockerListOptions{
 		Filters: watcher.NewDockerFilter(
 			watcher.DockerFilterContainer,
 			watcher.DockerFilterContainerNameID(w.ContainerID),
@@ -54,7 +54,7 @@ func SetMiddlewares(mws []map[string]any) error {
 	return nil
 }

-func SetAccessLogger(parent *task.Task, cfg *accesslog.Config) (err error) {
+func SetAccessLogger(parent task.Parent, cfg *accesslog.Config) (err error) {
 	epAccessLoggerMu.Lock()
 	defer epAccessLoggerMu.Unlock()
@@ -50,10 +50,12 @@ func Join(errors ...error) Error {
 	if n == 0 {
 		return nil
 	}
-	errs := make([]error, 0, n)
+	errs := make([]error, n)
+	i := 0
 	for _, err := range errors {
 		if err != nil {
-			errs = append(errs, err)
+			errs[i] = err
+			i++
 		}
 	}
 	return &nestedError{Extras: errs}
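The rewritten Join fills a pre-sized slice through an explicit index instead of appending. A standalone sketch of the same compaction pattern; it assumes, as the early return above implies, that n counts the non-nil errors:

package main

import (
	"errors"
	"fmt"
)

// compact copies the non-nil errors into a slice sized exactly to fit,
// mirroring the index-based fill used by Join in this commit.
func compact(errs ...error) []error {
	n := 0
	for _, err := range errs {
		if err != nil {
			n++
		}
	}
	if n == 0 {
		return nil
	}
	out := make([]error, n)
	i := 0
	for _, err := range errs {
		if err != nil {
			out[i] = err
			i++
		}
	}
	return out
}

func main() {
	fmt.Println(compact(nil, errors.New("a"), nil, errors.New("b")))
}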
@@ -22,6 +22,7 @@ type (
 		Path    string  `json:"path" validate:"required"`
 		Filters Filters `json:"filters"`
 		Fields  Fields  `json:"fields"`
+		// Retention *Retention
 	}
 )

@@ -31,7 +32,7 @@ var (
 	FormatJSON Format = "json"
 )

-const DefaultBufferSize = 100
+const DefaultBufferSize = 64 * 1024 // 64KB

 func DefaultConfig() *Config {
 	return &Config{
internal/net/http/accesslog/file_logger.go (new file, 22 lines added)

@@ -0,0 +1,22 @@
+package accesslog
+
+import (
+	"fmt"
+	"os"
+	"sync"
+
+	"github.com/yusing/go-proxy/internal/task"
+)
+
+type File struct {
+	*os.File
+	sync.Mutex
+}
+
+func NewFileAccessLogger(parent task.Parent, cfg *Config) (*AccessLogger, error) {
+	f, err := os.OpenFile(cfg.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
+	if err != nil {
+		return nil, fmt.Errorf("access log open error: %w", err)
+	}
+	return NewAccessLogger(parent, &File{File: f}, cfg), nil
+}
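The new file-backed logger wraps an *os.File with a mutex and hands it to the generic AccessLogger. A hedged usage sketch based only on the constructor and DefaultConfig shown in this diff; the package placement and log path are assumptions:

package example // illustrative placement

import (
	"github.com/yusing/go-proxy/internal/net/http/accesslog"
	"github.com/yusing/go-proxy/internal/task"
)

// newAccessLog opens an append-only access log owned by the given parent task.
func newAccessLog(parent task.Parent) (*accesslog.AccessLogger, error) {
	cfg := accesslog.DefaultConfig()
	cfg.Path = "/var/log/godoxy/access.log" // hypothetical path
	return accesslog.NewFileAccessLogger(parent, cfg)
}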
@@ -7,18 +7,17 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
+	"time"
 )

 type (
 	CommonFormatter struct {
 		cfg        *Fields
+		GetTimeNow func() time.Time // for testing purposes only
 	}
-	CombinedFormatter struct {
-		CommonFormatter
-	}
-	JSONFormatter struct {
-		CommonFormatter
-	}
+	CombinedFormatter CommonFormatter
+	JSONFormatter     CommonFormatter

 	JSONLogEntry struct {
 		Time string `json:"time"`
 		IP   string `json:"ip"`
@@ -39,6 +38,8 @@ type (
 	}
 )

+const LogTimeFormat = "02/Jan/2006:15:04:05 -0700"
+
 func scheme(req *http.Request) string {
 	if req.TLS != nil {
 		return "https"
@@ -62,7 +63,7 @@ func clientIP(req *http.Request) string {
 	return req.RemoteAddr
 }

-func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+func (f *CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
 	query := f.cfg.Query.ProcessQuery(req.URL.Query())

 	line.WriteString(req.Host)
@@ -71,7 +72,7 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
 	line.WriteString(clientIP(req))
 	line.WriteString(" - - [")
-	line.WriteString(timeNow())
+	line.WriteString(f.GetTimeNow().Format(LogTimeFormat))
 	line.WriteString("] \"")

 	line.WriteString(req.Method)
@@ -86,8 +87,8 @@ func (f CommonFormatter) Format(line *bytes.Buffer, req *http.Request, res *http
 	line.WriteString(strconv.FormatInt(res.ContentLength, 10))
 }

-func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
-	f.CommonFormatter.Format(line, req, res)
+func (f *CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+	(*CommonFormatter)(f).Format(line, req, res)
 	line.WriteString(" \"")
 	line.WriteString(req.Referer())
 	line.WriteString("\" \"")
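CombinedFormatter and JSONFormatter are now defined types over CommonFormatter rather than structs embedding it, so the shared implementation has to be reached through an explicit pointer conversion as above. A standalone sketch of that Go idiom (names here are illustrative, not from the repository):

package main

import "fmt"

type base struct{ prefix string }

func (b *base) Format(s string) string { return b.prefix + s }

// A defined type shares base's fields but NOT its method set
// (unlike struct embedding), so Format must be reached by conversion.
type combined base

func (c *combined) Format(s string) string {
	return (*base)(c).Format(s) + "!"
}

func main() {
	c := &combined{prefix: "> "}
	fmt.Println(c.Format("hello")) // "> hello!"
}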
@@ -95,14 +96,14 @@ func (f CombinedFormatter) Format(line *bytes.Buffer, req *http.Request, res *ht
 	line.WriteRune('"')
 }

-func (f JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
+func (f *JSONFormatter) Format(line *bytes.Buffer, req *http.Request, res *http.Response) {
 	query := f.cfg.Query.ProcessQuery(req.URL.Query())
 	headers := f.cfg.Headers.ProcessHeaders(req.Header)
 	headers.Del("Cookie")
 	cookies := f.cfg.Cookies.ProcessCookies(req.Cookies())

 	entry := JSONLogEntry{
-		Time:   timeNow(),
+		Time:   f.GetTimeNow().Format(LogTimeFormat),
 		IP:     clientIP(req),
 		Method: req.Method,
 		Scheme: scheme(req),
@@ -9,6 +9,7 @@ import (
 	"github.com/yusing/go-proxy/internal/common"
 	E "github.com/yusing/go-proxy/internal/error"
 	"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
+	"github.com/yusing/go-proxy/internal/route/routes"
 	"github.com/yusing/go-proxy/internal/task"
 	"github.com/yusing/go-proxy/internal/watcher/health"
 	"github.com/yusing/go-proxy/internal/watcher/health/monitor"
@@ -52,10 +53,13 @@ func New(cfg *Config) *LoadBalancer {
 }

 // Start implements task.TaskStarter.
-func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
+func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
 	lb.startTime = time.Now()
-	lb.task = routeSubtask
-	lb.task.OnFinished("loadbalancer cleanup", func() {
+	lb.task = parent.Subtask("loadbalancer."+lb.Link, false)
+	parent.OnCancel("lb_remove_route", func() {
+		routes.DeleteHTTPRoute(lb.Link)
+	})
+	lb.task.OnFinished("cleanup", func() {
 		if lb.impl != nil {
 			lb.pool.RangeAll(func(k string, v *Server) {
 				lb.impl.OnRemoveServer(v)
@@ -66,6 +70,11 @@ func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error {
 	return nil
 }

+// Task implements task.TaskStarter.
+func (lb *LoadBalancer) Task() *task.Task {
+	return lb.task
+}
+
 // Finish implements task.TaskFinisher.
 func (lb *LoadBalancer) Finish(reason any) {
 	lb.task.Finish(reason)
@@ -32,10 +32,10 @@ func setup() {
 		return
 	}

-	task := task.GlobalTask("error page")
-	dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
+	t := task.RootTask("error_page", true)
+	dirWatcher = W.NewDirectoryWatcher(t, errPagesBasePath)
 	loadContent()
-	go watchDir(task)
+	go watchDir()
 }

 func GetStaticFile(filename string) ([]byte, bool) {
@@ -73,11 +73,11 @@ func loadContent() {
 	}
 }

-func watchDir(task *task.Task) {
-	eventCh, errCh := dirWatcher.Events(task.Context())
+func watchDir() {
+	eventCh, errCh := dirWatcher.Events(task.RootContext())
 	for {
 		select {
-		case <-task.Context().Done():
+		case <-task.RootContextCanceled():
 			return
 		case event, ok := <-eventCh:
 			if !ok {
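watchDir no longer receives a task; it runs until the package-level root context is canceled. A standalone sketch of the same loop shape using a plain context (RootContext and RootContextCanceled are assumed to behave like a context and its Done channel, as this hunk suggests):

package main

import (
	"context"
	"fmt"
	"time"
)

// watch drains events until ctx is canceled, mirroring the select loop above.
func watch(ctx context.Context, events <-chan string) {
	for {
		select {
		case <-ctx.Done():
			return
		case ev, ok := <-events:
			if !ok {
				return
			}
			fmt.Println("event:", ev)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	ch := make(chan string, 1)
	ch <- "modified"
	go func() { time.Sleep(10 * time.Millisecond); cancel() }()
	watch(ctx, ch)
}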
@@ -24,8 +24,6 @@ type Server struct {
 	httpsStarted bool
 	startTime    time.Time

-	task *task.Task
-
 	l zerolog.Logger
 }
@@ -76,7 +74,6 @@ func NewServer(opt Options) (s *Server) {
 		CertProvider: opt.CertProvider,
 		http:         httpSer,
 		https:        httpsSer,
-		task:         task.GlobalTask(opt.Name + " server"),
 		l:            logger,
 	}
 }
@@ -108,7 +105,7 @@ func (s *Server) Start() {
 		s.l.Info().Str("addr", s.https.Addr).Msgf("server started")
 	}

-	s.task.OnFinished("stop server", s.stop)
+	task.OnProgramExit("server."+s.Name+".stop", s.stop)
 }

 func (s *Server) stop() {
@@ -117,12 +114,12 @@ func (s *Server) stop() {
 	}

 	if s.http != nil && s.httpStarted {
-		s.handleErr("http", s.http.Shutdown(s.task.Context()))
+		s.handleErr("http", s.http.Shutdown(task.RootContext()))
 		s.httpStarted = false
 	}

 	if s.https != nil && s.httpsStarted {
-		s.handleErr("https", s.https.Shutdown(s.task.Context()))
+		s.handleErr("https", s.https.Shutdown(task.RootContext()))
 		s.httpsStarted = false
 	}
 }
@@ -35,7 +35,7 @@ var (

 const dispatchErr = "notification dispatch error"

-func StartNotifDispatcher(parent *task.Task) *Dispatcher {
+func StartNotifDispatcher(parent task.Parent) *Dispatcher {
 	dispatcher = &Dispatcher{
 		task:  parent.Subtask("notification"),
 		logCh: make(chan *LogMessage),
@@ -73,19 +73,17 @@ func (r *HTTPRoute) String() string {
 	return r.TargetName()
 }

-// Start implements*task.TaskStarter.
-func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
+// Start implements task.TaskStarter.
+func (r *HTTPRoute) Start(parent task.Parent) E.Error {
 	if entry.ShouldNotServe(r) {
-		providerSubtask.Finish("should not serve")
 		return nil
 	}

-	r.task = providerSubtask
+	r.task = parent.Subtask("http."+r.TargetName(), false)

 	switch {
 	case entry.UseIdleWatcher(r):
-		wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
-		waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
+		waker, err := idlewatcher.NewHTTPWaker(r.task, r.ReverseProxyEntry, r.rp)
 		if err != nil {
 			r.task.Finish(err)
 			return err
@@ -98,7 +96,7 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
 		if err == nil {
 			fallback := monitor.NewHTTPHealthChecker(r.rp.TargetURL, r.Raw.HealthCheck)
 			r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
-			r.task.OnCancel("close docker client", client.Close)
+			r.task.OnCancel("close_docker_client", client.Close)
 		}
 	}
 	if r.HealthMon == nil {
@@ -137,29 +135,32 @@ func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error {
 	}

 	if r.HealthMon != nil {
-		healthMonTask := r.task.Subtask("health monitor")
-		if err := r.HealthMon.Start(healthMonTask); err != nil {
+		if err := r.HealthMon.Start(r.task); err != nil {
 			E.LogWarn("health monitor error", err, &r.l)
-			healthMonTask.Finish(err)
 		}
 	}

 	if entry.UseLoadBalance(r) {
-		r.addToLoadBalancer()
+		r.addToLoadBalancer(parent)
 	} else {
 		routes.SetHTTPRoute(r.TargetName(), r)
-		r.task.OnFinished("remove from route table", func() {
+		r.task.OnFinished("entrypoint_remove_route", func() {
 			routes.DeleteHTTPRoute(r.TargetName())
 		})
 	}

 	if common.PrometheusEnabled {
-		r.task.OnFinished("unreg metrics", r.rp.UnregisterMetrics)
+		r.task.OnFinished("metrics_cleanup", r.rp.UnregisterMetrics)
 	}
 	return nil
 }

-// Finish implements*task.TaskFinisher.
+// Task implements task.TaskStarter.
+func (r *HTTPRoute) Task() *task.Task {
+	return r.task
+}
+
+// Finish implements task.TaskFinisher.
 func (r *HTTPRoute) Finish(reason any) {
 	r.task.Finish(reason)
 }
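After this commit routes expose a three-method shape: Start(parent task.Parent), Task() *task.Task, and Finish(reason any). A hedged sketch of a minimal implementer, using only signatures visible in this diff (the interface names TaskStarter and TaskFinisher come from the task package changes later in this commit; the object itself is hypothetical):

package example // illustrative

import (
	E "github.com/yusing/go-proxy/internal/error"
	"github.com/yusing/go-proxy/internal/task"
)

type thing struct {
	task *task.Task
}

// Start implements task.TaskStarter.
func (t *thing) Start(parent task.Parent) E.Error {
	t.task = parent.Subtask("thing", false)
	t.task.OnFinished("cleanup", func() {
		// release resources owned by this object
	})
	return nil
}

// Task implements task.TaskStarter.
func (t *thing) Task() *task.Task { return t.task }

// Finish implements task.TaskFinisher.
func (t *thing) Finish(reason any) { t.task.Finish(reason) }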
@@ -168,7 +169,7 @@ func (r *HTTPRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	r.handler.ServeHTTP(w, req)
 }

-func (r *HTTPRoute) addToLoadBalancer() {
+func (r *HTTPRoute) addToLoadBalancer(parent task.Parent) {
 	var lb *loadbalancer.LoadBalancer
 	cfg := r.Raw.LoadBalance
 	l, ok := routes.GetHTTPRoute(cfg.Link)
@@ -182,11 +183,7 @@ func (r *HTTPRoute) addToLoadBalancer() {
 		}
 	} else {
 		lb = loadbalancer.New(cfg)
-		lbTask := r.task.Parent().Subtask("loadbalancer " + cfg.Link)
-		lbTask.OnCancel("remove lb from routes", func() {
-			routes.DeleteHTTPRoute(cfg.Link)
-		})
-		if err := lb.Start(lbTask); err != nil {
+		if err := lb.Start(parent); err != nil {
 			panic(err) // should always return nil
 		}
 		linked = &HTTPRoute{
@@ -203,9 +200,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
 		routes.SetHTTPRoute(cfg.Link, linked)
 	}
 	r.loadBalancer = lb
-	r.server = loadbalance.NewServer(r.task.String(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
+	r.server = loadbalance.NewServer(r.task.Name(), r.rp.TargetURL, r.Raw.LoadBalance.Weight, r.handler, r.HealthMon)
 	lb.AddServer(r.server)
-	r.task.OnCancel("remove server from lb", func() {
+	r.task.OnCancel("lb_remove_server", func() {
 		lb.RemoveServer(r.server)
 	})
 }
@@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler {
 	}
 }

-func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) {
+func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) {
 	oldRoutes := handler.provider.routes
 	newRoutes, err := handler.provider.loadRoutesImpl()
 	if err != nil {
@@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
 	return false
 }

-func (handler *EventHandler) Add(parent *task.Task, route *route.Route) {
+func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
 	err := handler.provider.startRoute(parent, route)
 	if err != nil {
 		handler.errs.Add(err.Subject("add"))
@@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) {
 	handler.removed.Adds(route.Entry.Alias)
 }

-func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) {
+func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) {
 	oldRoute.Finish("route update")
 	err := handler.provider.startRoute(parent, newRoute)
 	if err != nil {
@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"path"
+	"strings"
 	"time"

 	"github.com/rs/zerolog"
@@ -60,7 +61,7 @@ func NewFileProvider(filename string) (p *Provider, err error) {
 	if name == "" {
 		return nil, ErrEmptyProviderName
 	}
-	p = newProvider(name, ProviderTypeFile)
+	p = newProvider(strings.ReplaceAll(name, ".", "_"), ProviderTypeFile)
 	p.ProviderImpl, err = FileProviderImpl(filename)
 	if err != nil {
 		return nil, err
@@ -100,46 +101,43 @@ func (p *Provider) MarshalText() ([]byte, error) {
 	return []byte(p.String()), nil
 }

-func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error {
-	subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
-	err := r.Start(subtask)
+func (p *Provider) startRoute(parent task.Parent, r *R.Route) E.Error {
+	err := r.Start(parent)
 	if err != nil {
-		p.routes.Delete(r.Entry.Alias)
-		subtask.Finish(err) // just to ensure
 		return err.Subject(r.Entry.Alias)
 	}

 	p.routes.Store(r.Entry.Alias, r)
-	subtask.OnFinished("del from provider", func() {
+	r.Task().OnFinished("provider_remove_route", func() {
 		p.routes.Delete(r.Entry.Alias)
 	})
 	return nil
 }

 // Start implements*task.TaskStarter.
-func (p *Provider) Start(configSubtask *task.Task) E.Error {
-	// routes and event queue will stop on parent cancel
-	providerTask := configSubtask
+func (p *Provider) Start(parent task.Parent) E.Error {
+	t := parent.Subtask("provider."+p.name, false)

+	// routes and event queue will stop on config reload
 	errs := p.routes.CollectErrorsParallel(
 		func(alias string, r *R.Route) error {
-			return p.startRoute(providerTask, r)
+			return p.startRoute(t, r)
 		})

 	eventQueue := events.NewEventQueue(
-		providerTask,
+		t.Subtask("event_queue", false),
 		providerEventFlushInterval,
-		func(flushTask *task.Task, events []events.Event) {
+		func(events []events.Event) {
 			handler := p.newEventHandler()
 			// routes' lifetime should follow the provider's lifetime
-			handler.Handle(providerTask, events)
+			handler.Handle(t, events)
 			handler.Log()
-			flushTask.Finish("events flushed")
 		},
 		func(err E.Error) {
 			E.LogError("event error", err, p.Logger())
 		},
 	)
-	eventQueue.Start(p.watcher.Events(providerTask.Context()))
+	eventQueue.Start(p.watcher.Events(t.Context()))

 	if err := E.Join(errs...); err != nil {
 		return err.Subject(p.String())
@@ -47,20 +47,22 @@ func (r *StreamRoute) String() string {
 	return "stream " + r.TargetName()
 }

-// Start implements*task.TaskStarter.
-func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
+// Start implements task.TaskStarter.
+func (r *StreamRoute) Start(parent task.Parent) E.Error {
 	if entry.ShouldNotServe(r) {
-		providerSubtask.Finish("should not serve")
 		return nil
 	}

-	r.task = providerSubtask
+	r.task = parent.Subtask("stream." + r.TargetName())
 	r.Stream = NewStream(r)

+	parent.OnCancel("finish", func() {
+		r.task.Finish(nil)
+	})
+
 	switch {
 	case entry.UseIdleWatcher(r):
-		wakerTask := providerSubtask.Parent().Subtask("waker for " + r.TargetName())
-		waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
+		waker, err := idlewatcher.NewStreamWaker(r.task, r.StreamEntry, r.Stream)
 		if err != nil {
 			r.task.Finish(err)
 			return err
@@ -73,7 +75,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		if err == nil {
 			fallback := monitor.NewRawHealthChecker(r.TargetURL(), r.Raw.HealthCheck)
 			r.HealthMon = monitor.NewDockerHealthMonitor(client, r.Idlewatcher.ContainerID, r.TargetName(), r.Raw.HealthCheck, fallback)
-			r.task.OnCancel("close docker client", client.Close)
+			r.task.OnCancel("close_docker_client", client.Close)
 		}
 	}
 	if r.HealthMon == nil {
@@ -86,7 +88,7 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		return E.From(err)
 	}

-	r.task.OnFinished("close stream", func() {
+	r.task.OnFinished("close_stream", func() {
 		if err := r.Stream.Close(); err != nil {
 			E.LogError("close stream failed", err, &r.l)
 		}
@@ -97,22 +99,26 @@ func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error {
 		Msg("listening")

 	if r.HealthMon != nil {
-		healthMonTask := r.task.Subtask("health monitor")
-		if err := r.HealthMon.Start(healthMonTask); err != nil {
+		if err := r.HealthMon.Start(r.task); err != nil {
 			E.LogWarn("health monitor error", err, &r.l)
-			healthMonTask.Finish(err)
 		}
 	}

 	go r.acceptConnections()

 	routes.SetStreamRoute(r.TargetName(), r)
-	r.task.OnFinished("remove from route table", func() {
+	r.task.OnFinished("entrypoint_remove_route", func() {
 		routes.DeleteStreamRoute(r.TargetName())
 	})
 	return nil
 }

+// Task implements task.TaskStarter.
+func (r *StreamRoute) Task() *task.Task {
+	return r.task
+}
+
+// Finish implements task.TaskFinisher.
 func (r *StreamRoute) Finish(reason any) {
 	r.task.Finish(reason)
 }
@@ -95,7 +95,7 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
 		return fmt.Errorf("unexpected listener type: %T", stream)
 	}
 	case io.ReadWriteCloser:
-		stream.task.OnCancel("close conn", func() { conn.Close() })
+		stream.task.OnCancel("close_conn", func() { conn.Close() })

 		dialer := &net.Dialer{Timeout: streamDialTimeout}
 		dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
@ -4,355 +4,194 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var globalTask = createGlobalTask()
|
|
||||||
|
|
||||||
func createGlobalTask() (t *Task) {
|
|
||||||
t = new(Task)
|
|
||||||
t.name = "root"
|
|
||||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
|
||||||
t.subtasks = F.NewSet[*Task]()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func testResetGlobalTask() {
|
|
||||||
globalTask = createGlobalTask()
|
|
||||||
}
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
TaskStarter interface {
|
TaskStarter interface {
|
||||||
// Start starts the object that implements TaskStarter,
|
// Start starts the object that implements TaskStarter,
|
||||||
// and returns an error if it fails to start.
|
// and returns an error if it fails to start.
|
||||||
//
|
//
|
||||||
// The task passed must be a subtask of the caller task.
|
|
||||||
//
|
|
||||||
// callerSubtask.Finish must be called when start fails or the object is finished.
|
// callerSubtask.Finish must be called when start fails or the object is finished.
|
||||||
Start(callerSubtask *Task) E.Error
|
Start(parent Parent) E.Error
|
||||||
|
Task() *Task
|
||||||
}
|
}
|
||||||
TaskFinisher interface {
|
TaskFinisher interface {
|
||||||
// Finish marks the task as finished and cancel its context.
|
|
||||||
//
|
|
||||||
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
|
|
||||||
// of the task to finish.
|
|
||||||
//
|
|
||||||
// Note that it will also cancel all subtasks.
|
|
||||||
Finish(reason any)
|
Finish(reason any)
|
||||||
}
|
}
|
||||||
// Task controls objects' lifetime.
|
// Task controls objects' lifetime.
|
||||||
//
|
//
|
||||||
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
|
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
|
||||||
//
|
//
|
||||||
// When passing a Task object to another function,
|
|
||||||
// it must be a sub-Task of the current Task,
|
|
||||||
// in name of "`currentTaskName`Subtask"
|
|
||||||
//
|
|
||||||
// Use Task.Finish to stop all subtasks of the Task.
|
// Use Task.Finish to stop all subtasks of the Task.
|
||||||
Task struct {
|
Task struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
children sync.WaitGroup
|
||||||
|
|
||||||
|
onFinished sync.WaitGroup
|
||||||
|
finished chan struct{}
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelCauseFunc
|
cancel context.CancelCauseFunc
|
||||||
|
|
||||||
parent *Task
|
once sync.Once
|
||||||
subtasks F.Set[*Task]
|
}
|
||||||
subTasksWg sync.WaitGroup
|
Parent interface {
|
||||||
|
Context() context.Context
|
||||||
name string
|
Subtask(name string, needFinish ...bool) *Task
|
||||||
|
Name() string
|
||||||
OnFinishedFuncs []func()
|
Finish(reason any)
|
||||||
OnFinishedMu sync.Mutex
|
OnCancel(name string, f func())
|
||||||
onFinishedWg sync.WaitGroup
|
|
||||||
|
|
||||||
finishOnce sync.Once
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrProgramExiting = errors.New("program exiting")
|
|
||||||
ErrTaskCanceled = errors.New("task canceled")
|
|
||||||
|
|
||||||
logger = logging.With().Str("module", "task").Logger()
|
|
||||||
)
|
|
||||||
|
|
||||||
// GlobalTask returns a new Task with the given name, derived from the global context.
|
|
||||||
func GlobalTask(format string, args ...any) *Task {
|
|
||||||
if len(args) > 0 {
|
|
||||||
format = fmt.Sprintf(format, args...)
|
|
||||||
}
|
|
||||||
return globalTask.Subtask(format)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugTaskMap returns a map[string]any representation of the global task tree.
|
|
||||||
//
|
|
||||||
// The returned map is suitable for encoding to JSON, and can be used
|
|
||||||
// to debug the task tree.
|
|
||||||
//
|
|
||||||
// The returned map is not guaranteed to be stable, and may change
|
|
||||||
// between runs of the program. It is intended for debugging purposes
|
|
||||||
// only.
|
|
||||||
func DebugTaskMap() map[string]any {
|
|
||||||
return globalTask.serialize()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CancelGlobalContext cancels the global task context, which will cause all tasks
|
|
||||||
// created to be canceled. This should be called before exiting the program
|
|
||||||
// to ensure that all tasks are properly cleaned up.
|
|
||||||
func CancelGlobalContext() {
|
|
||||||
globalTask.cancel(ErrProgramExiting)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
|
|
||||||
//
|
|
||||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
|
||||||
// still running when the timeout was reached, and their current tree
|
|
||||||
// of subtasks.
|
|
||||||
func GlobalContextWait(timeout time.Duration) (err error) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
after := time.After(timeout)
|
|
||||||
go func() {
|
|
||||||
globalTask.Wait()
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
case <-after:
|
|
||||||
logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
|
|
||||||
return context.DeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) trace(msg string) {
|
|
||||||
logger.Trace().Str("name", t.name).Msg(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the name of the task.
|
|
||||||
func (t *Task) Name() string {
|
|
||||||
if !common.IsTrace {
|
|
||||||
return t.name
|
|
||||||
}
|
|
||||||
parts := strings.Split(t.name, " > ")
|
|
||||||
return parts[len(parts)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the name of the task.
|
|
||||||
func (t *Task) String() string {
|
|
||||||
return t.name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context returns the context associated with the task. This context is
|
|
||||||
// canceled when Finish of the task is called, or parent task is canceled.
|
|
||||||
func (t *Task) Context() context.Context {
|
func (t *Task) Context() context.Context {
|
||||||
return t.ctx
|
return t.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Task) Finished() <-chan struct{} {
|
||||||
|
return t.finished
|
||||||
|
}
|
||||||
|
|
||||||
// FinishCause returns the reason / error that caused the task to be finished.
|
// FinishCause returns the reason / error that caused the task to be finished.
|
||||||
func (t *Task) FinishCause() error {
|
func (t *Task) FinishCause() error {
|
||||||
cause := context.Cause(t.ctx)
|
return context.Cause(t.ctx)
|
||||||
if cause == nil {
|
|
||||||
return t.ctx.Err()
|
|
||||||
}
|
|
||||||
return cause
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parent returns the parent task of the current task.
|
// OnFinished calls fn when the task is canceled and all subtasks are finished.
|
||||||
func (t *Task) Parent() *Task {
|
|
||||||
return t.parent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) runAllOnFinished(onCompTask *Task) {
|
|
||||||
<-t.ctx.Done()
|
|
||||||
t.WaitSubTasks()
|
|
||||||
for _, OnFinishedFunc := range t.OnFinishedFuncs {
|
|
||||||
OnFinishedFunc()
|
|
||||||
t.onFinishedWg.Done()
|
|
||||||
}
|
|
||||||
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnFinished calls fn when all subtasks are finished.
|
|
||||||
//
|
//
|
||||||
// It cannot be called after Finish or Wait is called.
|
// It should not be called after Finish is called.
|
||||||
func (t *Task) OnFinished(about string, fn func()) {
|
func (t *Task) OnFinished(about string, fn func()) {
|
||||||
if t.parent == globalTask {
|
t.onCancel(about, fn, true)
|
||||||
t.OnCancel(about, fn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.onFinishedWg.Add(1)
|
|
||||||
t.OnFinishedMu.Lock()
|
|
||||||
defer t.OnFinishedMu.Unlock()
|
|
||||||
|
|
||||||
if t.OnFinishedFuncs == nil {
|
|
||||||
onCompTask := GlobalTask(t.name + " > OnFinished > " + about)
|
|
||||||
go t.runAllOnFinished(onCompTask)
|
|
||||||
}
|
|
||||||
idx := len(t.OnFinishedFuncs)
|
|
||||||
wrapped := func() {
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
logger.Error().
|
|
||||||
Str("name", t.name).
|
|
||||||
Interface("err", err).
|
|
||||||
Msg("panic in " + about)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
fn()
|
|
||||||
logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about)
|
|
||||||
}
|
|
||||||
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnCancel calls fn when the task is canceled.
|
// OnCancel calls fn when the task is canceled.
|
||||||
//
|
//
|
||||||
// It cannot be called after Finish or Wait is called.
|
// It should not be called after Finish is called.
|
||||||
func (t *Task) OnCancel(about string, fn func()) {
|
func (t *Task) OnCancel(about string, fn func()) {
|
||||||
onCompTask := GlobalTask(t.name + " > OnFinished")
|
t.onCancel(about, fn, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
|
||||||
|
t.onFinished.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
<-t.ctx.Done()
|
<-t.ctx.Done()
|
||||||
fn()
|
if waitSubTasks {
|
||||||
onCompTask.Finish("done")
|
t.children.Wait()
|
||||||
t.trace("onCancel done: " + about)
|
}
|
||||||
|
t.invokeWithRecover(fn, about)
|
||||||
|
t.onFinished.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finish marks the task as finished and cancel its context.
|
// Finish cancel all subtasks and wait for them to finish,
|
||||||
//
|
// then marks the task as finished, with the given reason (if any).
|
||||||
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
|
|
||||||
// of the task to finish.
|
|
||||||
//
|
|
||||||
// Note that it will also cancel all subtasks.
|
|
||||||
func (t *Task) Finish(reason any) {
|
func (t *Task) Finish(reason any) {
|
||||||
var format string
|
select {
|
||||||
switch reason.(type) {
|
case <-t.finished:
|
||||||
case error:
|
return
|
||||||
format = "%w"
|
|
||||||
case string, fmt.Stringer:
|
|
||||||
format = "%s"
|
|
||||||
default:
|
default:
|
||||||
format = "%v"
|
t.once.Do(func() {
|
||||||
|
t.finish(reason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) finish(reason any) {
|
||||||
|
t.cancel(fmtCause(reason))
|
||||||
|
t.children.Wait()
|
||||||
|
t.onFinished.Wait()
|
||||||
|
if t.finished != nil {
|
||||||
|
close(t.finished)
|
||||||
|
}
|
||||||
|
logger.Trace().Msg("task " + t.name + " finished")
|
||||||
|
}
|
||||||
|
|
||||||
|
func fmtCause(cause any) error {
|
||||||
|
switch cause := cause.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case error:
|
||||||
|
return cause
|
||||||
|
case string:
|
||||||
|
return errors.New(cause)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%v", cause)
|
||||||
}
|
}
|
||||||
t.finishOnce.Do(func() {
|
|
||||||
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
|
|
||||||
})
|
|
||||||
t.Wait()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||||
//
|
//
|
||||||
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
|
// This should not be called after Finish is called.
|
||||||
//
|
func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||||
// This should not be called after Finish, Wait, or WaitSubTasks is called.
|
nf := len(needFinish) == 0 || needFinish[0]
|
||||||
func (t *Task) Subtask(name string) *Task {
|
|
||||||
		ctx, cancel := context.WithCancelCause(t.ctx)
		return t.newSubTask(ctx, cancel, name)
}

func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task {
	parent := t
	if common.IsTrace {
		name = parent.name + " > " + name
	}
	subtask := &Task{
		ctx:      ctx,
		cancel:   cancel,
		name:     name,
		parent:   parent,
		subtasks: F.NewSet[*Task](),
	}
	parent.subTasksWg.Add(1)
	parent.subtasks.Add(subtask)
	if common.IsTrace {
		subtask.trace("started")
		go func() {
			subtask.Wait()
			subtask.trace("finished: " + subtask.FinishCause().Error())
		}()
	}
	go func() {
		subtask.Wait()
		parent.subtasks.Remove(subtask)
		parent.subTasksWg.Done()
	}()
	return subtask
}
	ctx, cancel := context.WithCancelCause(t.ctx)
	child := &Task{
		finished: make(chan struct{}, 1),
		ctx:      ctx,
		cancel:   cancel,
	}
	if t != root {
		child.name = t.name + "." + name
		allTasks.Add(child)
	} else {
		child.name = name
	}

	allTasksWg.Add(1)
	t.children.Add(1)

	if !nf {
		go func() {
			<-child.ctx.Done()
			child.Finish(nil)
		}()
	}

	go func() {
		<-child.finished
		allTasksWg.Done()
		t.children.Done()
		allTasks.Remove(child)
	}()

	logger.Trace().Msg("task " + child.name + " started")
	return child
}

// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
//
// It must be called only after Finish is called.
func (t *Task) Wait() {
	<-t.ctx.Done()
	t.WaitSubTasks()
	t.onFinishedWg.Wait()
}
// Name returns the name of the task without parent names.
func (t *Task) Name() string {
	parts := strutils.SplitRune(t.name, '.')
	return parts[len(parts)-1]
}

// WaitSubTasks waits for all subtasks of the task to finish.
//
// No more subtasks can be added after this call.
//
// It can be called before Finish is called.
func (t *Task) WaitSubTasks() {
	t.subTasksWg.Wait()
}
// String returns the full name of the task.
func (t *Task) String() string {
	return t.name
}

// tree returns a string representation of the task tree, with the given
// prefix prepended to each line. The prefix is used to indent the tree,
// and should be a string of spaces or a similar separator.
//
// The resulting string is suitable for printing to the console, and can be
// used to debug the task tree.
//
// The tree is traversed in a depth-first manner, with each task's name and
// line number (if available) printed on a separate line. The line number is
// only printed if the task was created with a non-empty line argument.
//
// The returned string is not guaranteed to be stable, and may change between
// runs of the program. It is intended for debugging purposes only.
func (t *Task) tree(prefix ...string) string {
	var sb strings.Builder
	var pre string
	if len(prefix) > 0 {
		pre = prefix[0]
		sb.WriteString(pre + "- ")
	}
	sb.WriteString(t.Name() + "\n")
	t.subtasks.RangeAll(func(subtask *Task) {
		sb.WriteString(subtask.tree(pre + " "))
	})
	return sb.String()
}
// MarshalText implements encoding.TextMarshaler.
func (t *Task) MarshalText() ([]byte, error) {
	return []byte(t.name), nil
}

// serialize returns a map[string]any representation of the task tree.
//
// The map contains the following keys:
// - name: the name of the task
// - subtasks: a slice of maps, each representing a subtask
//
// The subtask maps contain the same keys, recursively.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func (t *Task) serialize() map[string]any {
	m := make(map[string]any)
	parts := strings.Split(t.name, " > ")
	m["name"] = parts[len(parts)-1]
	if t.subtasks.Size() > 0 {
		m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
		t.subtasks.RangeAll(func(subtask *Task) {
			m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
		})
	}
	return m
}
func (t *Task) invokeWithRecover(fn func(), caller string) {
	defer func() {
		if err := recover(); err != nil {
			logger.Error().
				Interface("err", err).
				Msg("panic in task " + t.name + "." + caller)
			if common.IsDebug {
				panic(string(debug.Stack()))
			}
		}
	}()
	fn()
}
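The diff above replaces the old parent/subtask bookkeeping (subTasksWg, a per-parent subtasks set) with a flat registry (allTasks, allTasksWg) plus per-child auto-finish. As a rough, standard-library-only sketch of the lifecycle the new code implements (this is not code from this repository; the child type and all names below are made up for illustration):

package main

import (
	"context"
	"fmt"
	"sync"
)

// child mirrors the simplified Task: a cancelable context, and a "finished" signal.
type child struct {
	ctx      context.Context
	cancel   context.CancelCauseFunc
	finished chan struct{}
	once     sync.Once
}

func (c *child) finish(cause error) {
	c.once.Do(func() {
		c.cancel(cause)
		close(c.finished)
	})
}

func main() {
	var all sync.WaitGroup // plays the role of allTasksWg

	parentCtx, parentCancel := context.WithCancelCause(context.Background())

	ctx, cancel := context.WithCancelCause(parentCtx)
	c := &child{ctx: ctx, cancel: cancel, finished: make(chan struct{})}

	all.Add(1)
	go func() { // auto-finish once the context is done (the needFinish == false case)
		<-c.ctx.Done()
		c.finish(context.Cause(c.ctx))
	}()
	go func() { // bookkeeping: release the wait group when the child reports finished
		<-c.finished
		all.Done()
	}()

	parentCancel(fmt.Errorf("program exiting")) // cancellation propagates to the child
	all.Wait()                                  // graceful shutdown: wait for children
	fmt.Println("child finished:", context.Cause(c.ctx))
}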
@ -2,132 +2,112 @@ package task
import (
	"context"
	"sync/atomic"
	"sync"
	"testing"
	"time"

	. "github.com/yusing/go-proxy/internal/utils/testing"
)

const (
	rootTaskName = "root-task"
	subTaskName  = "subtask"
)

func TestTaskCreation(t *testing.T) {
	rootTask := GlobalTask(rootTaskName)
	subTask := rootTask.Subtask(subTaskName)

	ExpectEqual(t, rootTaskName, rootTask.Name())
	ExpectEqual(t, subTaskName, subTask.Name())
}
func testTask() *Task {
	return RootTask("test", false)
}

func TestTaskCancellation(t *testing.T) {
	subTaskDone := make(chan struct{})

	rootTask := GlobalTask(rootTaskName)
	subTask := rootTask.Subtask(subTaskName)

	go func() {
		subTask.Wait()
		close(subTaskDone)
	}()

	go rootTask.Finish(nil)

	select {
	case <-subTaskDone:
		err := subTask.Context().Err()
		ExpectError(t, context.Canceled, err)
		cause := context.Cause(subTask.Context())
		ExpectError(t, ErrTaskCanceled, cause)
	case <-time.After(1 * time.Second):
		t.Fatal("subTask context was not canceled as expected")
	}
}
func TestChildTaskCancellation(t *testing.T) {
	t.Cleanup(testCleanup)

	parent := testTask()
	child := parent.Subtask("")

	go func() {
		defer child.Finish(nil)
		for {
			select {
			case <-child.Context().Done():
				return
			default:
				continue
			}
		}
	}()

	parent.cancel(nil) // should also cancel child

	select {
	case <-child.Finished():
		ExpectError(t, context.Canceled, child.Context().Err())
	default:
		t.Fatal("subTask context was not canceled as expected")
	}
}

func TestOnComplete(t *testing.T) {
	rootTask := GlobalTask(rootTaskName)
	task := rootTask.Subtask(subTaskName)

	var value atomic.Int32
	task.OnFinished("set value", func() {
		value.Store(1234)
	})

	task.Finish(nil)
	ExpectEqual(t, value.Load(), 1234)
}
func TestTaskOnCancelOnFinished(t *testing.T) {
	t.Cleanup(testCleanup)
	task := testTask()

	var shouldTrueOnCancel bool
	var shouldTrueOnFinish bool

	task.OnCancel("", func() {
		shouldTrueOnCancel = true
	})
	task.OnFinished("", func() {
		shouldTrueOnFinish = true
	})

	ExpectFalse(t, shouldTrueOnFinish)
	task.Finish(nil)
	ExpectTrue(t, shouldTrueOnCancel)
	ExpectTrue(t, shouldTrueOnFinish)
}

func TestGlobalContextWait(t *testing.T) {
	testResetGlobalTask()
	defer CancelGlobalContext()

	rootTask := GlobalTask(rootTaskName)

	finished1, finished2 := false, false

	subTask1 := rootTask.Subtask(subTaskName)
	subTask2 := rootTask.Subtask(subTaskName)
	subTask1.OnFinished("", func() {
		finished1 = true
	})
	subTask2.OnFinished("", func() {
		finished2 = true
	})

	go func() {
		time.Sleep(500 * time.Millisecond)
		subTask1.Finish(nil)
	}()

	go func() {
		time.Sleep(500 * time.Millisecond)
		subTask2.Finish(nil)
	}()

	go func() {
		subTask1.Wait()
		subTask2.Wait()
		rootTask.Finish(nil)
	}()

	_ = GlobalContextWait(1 * time.Second)
	ExpectTrue(t, finished1)
	ExpectTrue(t, finished2)
	ExpectError(t, context.Canceled, rootTask.Context().Err())
	ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
	ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
}
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
	t.Cleanup(testCleanup)
	task := testTask()

	finished := false
	task.OnFinished("", func() {
		finished = true
	})

	go func() {
		defer task.Finish(nil)
		for {
			select {
			case <-task.Context().Done():
				return
			default:
				continue
			}
		}
	}()

	ExpectNoError(t, GracefulShutdown(1*time.Second))
	ExpectTrue(t, finished)

	<-root.finished
	ExpectError(t, context.Canceled, task.Context().Err())
	ExpectError(t, ErrProgramExiting, task.FinishCause())
}

func TestTimeoutOnGlobalContextWait(t *testing.T) {
	testResetGlobalTask()

	rootTask := GlobalTask(rootTaskName)
	rootTask.Subtask(subTaskName)

	ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
}
func TestTimeoutOnGracefulShutdown(t *testing.T) {
	t.Cleanup(testCleanup)
	_ = testTask()

	ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
}

func TestGlobalContextCancellation(t *testing.T) {
	testResetGlobalTask()

	taskDone := make(chan struct{})
	rootTask := GlobalTask(rootTaskName)

	go func() {
		rootTask.Wait()
		close(taskDone)
	}()

	CancelGlobalContext()

	select {
	case <-taskDone:
		err := rootTask.Context().Err()
		ExpectError(t, context.Canceled, err)
		cause := context.Cause(rootTask.Context())
		ExpectError(t, ErrProgramExiting, cause)
	case <-time.After(1 * time.Second):
		t.Fatal("subTask context was not canceled as expected")
	}
}
func TestFinishMultipleCalls(t *testing.T) {
	t.Cleanup(testCleanup)
	task := testTask()
	var wg sync.WaitGroup
	wg.Add(5)
	for range 5 {
		go func() {
			defer wg.Done()
			task.Finish(nil)
		}()
	}
	wg.Wait()
}
96
internal/task/utils.go
Normal file

@ -0,0 +1,96 @@
package task

import (
	"context"
	"encoding/json"
	"errors"
	"slices"
	"sync"
	"time"

	"github.com/yusing/go-proxy/internal/logging"
	F "github.com/yusing/go-proxy/internal/utils/functional"
)

var ErrProgramExiting = errors.New("program exiting")

var logger = logging.With().Str("module", "task").Logger()

var root = newRoot()
var allTasks = F.NewSet[*Task]()
var allTasksWg sync.WaitGroup

func testCleanup() {
	root = newRoot()
	allTasks.Clear()
	allTasksWg = sync.WaitGroup{}
}

// RootTask returns a new Task with the given name, derived from the root context.
func RootTask(name string, needFinish bool) *Task {
	return root.Subtask(name, needFinish)
}

func newRoot() *Task {
	t := &Task{name: "root"}
	t.ctx, t.cancel = context.WithCancelCause(context.Background())
	return t
}

func RootContext() context.Context {
	return root.ctx
}

func RootContextCanceled() <-chan struct{} {
	return root.ctx.Done()
}

func OnProgramExit(about string, fn func()) {
	root.OnFinished(about, fn)
}

// GracefulShutdown waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GracefulShutdown(timeout time.Duration) (err error) {
	root.cancel(ErrProgramExiting)

	done := make(chan struct{})
	after := time.After(timeout)

	go func() {
		allTasksWg.Wait()
		close(done)
	}()

	for {
		select {
		case <-done:
			return
		case <-after:
			b, err := json.Marshal(DebugTaskList())
			if err != nil {
				logger.Warn().Err(err).Msg("failed to marshal tasks")
				return context.DeadlineExceeded
			}
			logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
			return context.DeadlineExceeded
		}
	}
}

// DebugTaskList returns list of all tasks.
//
// The returned string is suitable for printing to the console.
func DebugTaskList() []string {
	l := make([]string, 0, allTasks.Size())

	allTasks.RangeAll(func(t *Task) {
		l = append(l, t.name)
	})

	slices.Sort(l)
	return l
}
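Taken together, the helpers in internal/task/utils.go replace the old GlobalTask / CancelGlobalContext / GlobalContextWait entry points. The snippet below is a hypothetical caller, not the project's actual main: it assumes Task exposes Context() as the tests in this commit use it, and the exact timing of OnFinished callbacks during shutdown may differ.

package main

import (
	"log"
	"time"

	"github.com/yusing/go-proxy/internal/task"
)

func main() {
	// A long-running component owns its task (needFinish = true) and must call
	// Finish itself once its work stops.
	server := task.RootTask("http-server", true)

	go func() {
		defer server.Finish(nil)  // release the task when work stops
		<-server.Context().Done() // stop when the root is canceled
	}()

	// Cleanup hooks registered on the root task run at program exit.
	task.OnProgramExit("flush-logs", func() {
		log.Println("flushing logs before exit")
	})

	// Cancel the root with ErrProgramExiting and wait up to 3s for all tasks.
	if err := task.GracefulShutdown(3 * time.Second); err != nil {
		log.Println("tasks still running:", task.DebugTaskList())
	}
}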
@ -1,30 +0,0 @@
package utils

import (
	"encoding/json"
	"sync/atomic"
)

type AtomicValue[T any] struct {
	atomic.Value
}

func (a *AtomicValue[T]) Load() T {
	return a.Value.Load().(T)
}

func (a *AtomicValue[T]) Store(v T) {
	a.Value.Store(v)
}

func (a *AtomicValue[T]) Swap(v T) T {
	return a.Value.Swap(v).(T)
}

func (a *AtomicValue[T]) CompareAndSwap(oldV, newV T) bool {
	return a.Value.CompareAndSwap(oldV, newV)
}

func (a *AtomicValue[T]) MarshalJSON() ([]byte, error) {
	return json.Marshal(a.Load())
}
30
internal/utils/atomic/atomic_value.go
Normal file

@ -0,0 +1,30 @@
package atomic

import (
	"encoding/json"
	"sync/atomic"
)

type Value[T any] struct {
	atomic.Value
}

func (a *Value[T]) Load() T {
	return a.Value.Load().(T)
}

func (a *Value[T]) Store(v T) {
	a.Value.Store(v)
}

func (a *Value[T]) Swap(v T) T {
	return a.Value.Swap(v).(T)
}

func (a *Value[T]) CompareAndSwap(oldV, newV T) bool {
	return a.Value.CompareAndSwap(oldV, newV)
}

func (a *Value[T]) MarshalJSON() ([]byte, error) {
	return json.Marshal(a.Load())
}
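The new internal/utils/atomic package is a thin generic wrapper over sync/atomic.Value. A minimal usage sketch follows (the Status type is made up for illustration); note that Load() type-asserts the stored value, so a value must be stored before the first Load.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/yusing/go-proxy/internal/utils/atomic"
)

type Status string

func main() {
	var s atomic.Value[Status]

	// Store before Load: Load() asserts the stored value to T and would
	// panic on an empty sync/atomic.Value.
	s.Store(Status("healthy"))
	fmt.Println(s.Load())                                 // healthy
	fmt.Println(s.Swap("unhealthy"))                      // healthy (previous value)
	fmt.Println(s.CompareAndSwap("unhealthy", "healthy")) // true

	// MarshalJSON serializes the current value, handy for status endpoints.
	b, _ := json.Marshal(&s)
	fmt.Println(string(b)) // "healthy"
}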
@ -152,9 +152,10 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
		return m.CollectErrors(do)
	}

	errs := make([]error, 0)
	var errs []error
	mu := sync.Mutex{}
	var mu sync.Mutex
	wg := sync.WaitGroup{}
	var wg sync.WaitGroup

	m.Range(func(k KT, v VT) bool {
		wg.Add(1)
		go func() {
@ -171,24 +172,6 @@ func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error {
	return errs
}

// RemoveAll removes all key-value pairs from the map where the value matches the given criteria.
//
// Parameters:
//
// criteria: function to determine whether a value should be removed
//
// Returns:
//
// nothing
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
	m.Range(func(k KT, v VT) bool {
		if criteria(v) {
			m.Delete(k)
		}
		return true
	})
}

func (m Map[KT, VT]) Has(k KT) bool {
	_, ok := m.Load(k)
	return ok
@ -61,3 +61,7 @@ func (set Set[T]) RangeAllParallel(f func(T)) {
func (set Set[T]) Size() int {
	return set.m.Size()
}

func (set Set[T]) IsEmpty() bool {
	return set.m == nil || set.m.Size() == 0
}
@ -20,10 +20,17 @@ func IgnoreError[Result any](r Result, _ error) Result {
	return r
}

func fmtError(err error) string {
	if err == nil {
		return "<nil>"
	}
	return ansi.StripANSI(err.Error())
}

func ExpectNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil && !reflect.ValueOf(err).IsNil() {
	if err != nil {
		t.Errorf("expected err=nil, got %s", ansi.StripANSI(err.Error()))
		t.Errorf("expected err=nil, got %s", fmtError(err))
		t.FailNow()
	}
}
@ -31,7 +38,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) {
	t.Helper()
	if !errors.Is(err, expected) {
		t.Errorf("expected err %s, got %s", expected, ansi.StripANSI(err.Error()))
		t.Errorf("expected err %s, got %s", expected, fmtError(err))
		t.FailNow()
	}
}
@ -39,7 +46,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) {
	t.Helper()
	if !errors.Is(err, expected) {
		t.Errorf("%v: expected err %s, got %s", input, expected, ansi.StripANSI(err.Error()))
		t.Errorf("%v: expected err %s, got %s", input, expected, fmtError(err))
		t.FailNow()
	}
}
@ -48,7 +55,7 @@ func ExpectErrorT[T error](t *testing.T, err error) {
	t.Helper()
	var errAs T
	if !errors.As(err, &errAs) {
		t.Errorf("expected err %T, got %s", errAs, ansi.StripANSI(err.Error()))
		t.Errorf("expected err %T, got %s", errAs, fmtError(err))
		t.FailNow()
	}
}
@ -16,8 +16,10 @@ var (
func NewConfigFileWatcher(filename string) Watcher {
	configDirWatcherMu.Lock()
	defer configDirWatcherMu.Unlock()

	if configDirWatcher == nil {
		configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath)
		t := task.RootTask("config_dir_watcher", false)
		configDirWatcher = NewDirectoryWatcher(t, common.ConfigBasePath)
	}
	return configDirWatcher.Add(filename)
}
@ -37,7 +37,7 @@ type DirWatcher struct {
//
// Note that the returned DirWatcher is not ready to use until the goroutine
// started by NewDirectoryWatcher has finished.
func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
	//! subdirectories are not watched
	w, err := fsnotify.NewWatcher()
	if err != nil {
@ -56,7 +56,7 @@ func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher {
		fwMap:   F.NewMapOf[string, *fileWatcher](),
		eventCh: make(chan Event),
		errCh:   make(chan E.Error),
		task:    callerSubtask,
		task:    parent.Subtask("dir_watcher(" + dirPath + ")"),
	}
	go helper.start()
	return helper
@ -80,17 +80,19 @@ func (h *DirWatcher) Add(relPath string) Watcher {
		eventCh: make(chan Event),
		errCh:   make(chan E.Error),
	}
	h.task.OnFinished("close file watcher for "+relPath, func() {
		close(s.eventCh)
		close(s.errCh)
	})
	h.fwMap.Store(relPath, s)
	return s
}

func (h *DirWatcher) cleanup() {
	h.w.Close()
	close(h.eventCh)
	close(h.errCh)
	h.task.Finish(nil)
}

func (h *DirWatcher) start() {
	defer close(h.eventCh)
	defer h.cleanup()
	defer h.w.Close()

	for {
		select {
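The DirWatcher change moves channel closing out of the per-file OnFinished callbacks and into a single cleanup() that the event loop defers, so the channels are closed by their only sender and only after the loop has returned. A generic, standard-library sketch of that ownership pattern (not the project's watcher types; all names here are illustrative):

package main

import (
	"context"
	"fmt"
	"time"
)

// watch owns eventCh and errCh: it is the only sender, so it is the one that
// closes them, and only after its loop has returned.
func watch(ctx context.Context, eventCh chan<- string, errCh chan<- error) {
	cleanup := func() {
		close(eventCh)
		close(errCh)
	}
	defer cleanup()

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return // cleanup runs after the last possible send
		case t := <-ticker.C:
			eventCh <- "tick at " + t.Format(time.RFC3339Nano)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
	defer cancel()

	eventCh := make(chan string)
	errCh := make(chan error)
	go watch(ctx, eventCh, errCh)

	for ev := range eventCh { // range exits cleanly once eventCh is closed
		fmt.Println(ev)
	}
	if err, ok := <-errCh; ok && err != nil {
		fmt.Println("error:", err)
	}
}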
@ -1,6 +1,7 @@
package events

import (
	"runtime/debug"
	"time"

	"github.com/yusing/go-proxy/internal/common"
@ -17,7 +18,7 @@ type (
		onFlush OnFlushFunc
		onError OnErrorFunc
	}
	OnFlushFunc = func(flushTask *task.Task, events []Event)
	OnFlushFunc = func(events []Event)
	OnErrorFunc = func(err E.Error)
)
@ -38,9 +39,9 @@ const eventQueueCapacity = 10
// but the onFlush function can return earlier (e.g. run in another goroutine).
//
// If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
func NewEventQueue(queueTask *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
	return &EventQueue{
		task:          parent.Subtask("event queue"),
		task:          queueTask,
		queue:         make([]Event, 0, eventQueueCapacity),
		ticker:        time.NewTicker(flushInterval),
		flushInterval: flushInterval,
@ -50,19 +51,20 @@ func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlu
}

func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
	if common.IsProduction {
		origOnFlush := e.onFlush
		// recover panic in onFlush when in production mode
		e.onFlush = func(flushTask *task.Task, events []Event) {
			defer func() {
				if err := recover(); err != nil {
					e.onError(E.New("recovered panic in onFlush").
						Withf("%v", err).
						Subject(e.task.Parent().String()))
				}
			}()
			origOnFlush(flushTask, events)
		}
	}
	origOnFlush := e.onFlush
	// recover panic in onFlush when in production mode
	e.onFlush = func(events []Event) {
		defer func() {
			if err := recover(); err != nil {
				e.onError(E.New("recovered panic in onFlush").
					Withf("%v", err).
					Subject(e.task.Name()))
				if common.IsDebug {
					panic(string(debug.Stack()))
				}
			}
		}()
		origOnFlush(events)
	}

	go func() {
@ -75,19 +77,24 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
			return
		case <-e.ticker.C:
			if len(e.queue) > 0 {
				flushTask := e.task.Subtask("flush events")
				queue := e.queue
				e.queue = make([]Event, 0, eventQueueCapacity)
				go e.onFlush(flushTask, queue)
				flushTask.Wait()
				// clone -> clear -> flush
				queue := make([]Event, len(e.queue))
				copy(queue, e.queue)
				e.queue = e.queue[:0]

				e.onFlush(queue)
			}
			e.ticker.Reset(e.flushInterval)
		case event, ok := <-eventCh:
			e.queue = append(e.queue, event)
			if !ok {
				return
			}
		case err := <-errCh:
			e.queue = append(e.queue, event)
		case err, ok := <-errCh:
			if !ok {
				return
			}
			if err != nil {
				e.onError(err)
			}
@ -95,10 +102,3 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
		}
	}()
}

// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
	e.task.Wait()
}
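The flush path now follows the clone -> clear -> flush idiom instead of spawning a flush subtask: copy the pending events, truncate the buffer in place so its capacity is reused, then call onFlush synchronously. A small self-contained sketch of the idiom (string events stand in for the real Event type; the function names are made up):

package main

import "fmt"

// flushBuffered demonstrates the clone -> clear -> flush idiom used by the
// event queue: copy the pending items, truncate the buffer in place so its
// capacity is reused, then hand the copy to the flush callback.
func flushBuffered(queue *[]string, onFlush func([]string)) {
	if len(*queue) == 0 {
		return
	}
	batch := make([]string, len(*queue))
	copy(batch, *queue)
	*queue = (*queue)[:0] // keep capacity, drop contents

	onFlush(batch) // safe: batch is independent of the live buffer
}

func main() {
	queue := make([]string, 0, 10)
	queue = append(queue, "route added", "route removed")

	flushBuffered(&queue, func(events []string) {
		fmt.Println("flushing", len(events), "events:", events)
	})

	fmt.Println("pending after flush:", len(queue), "cap:", cap(queue))
}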
@ -13,7 +13,7 @@ import (
	"github.com/yusing/go-proxy/internal/net/types"
	"github.com/yusing/go-proxy/internal/notif"
	"github.com/yusing/go-proxy/internal/task"
	U "github.com/yusing/go-proxy/internal/utils"
	"github.com/yusing/go-proxy/internal/utils/atomic"
	"github.com/yusing/go-proxy/internal/utils/strutils"
	"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -23,9 +23,9 @@ type (
	monitor struct {
		service string
		config  *health.HealthCheckConfig
		url     U.AtomicValue[types.URL]
		url     atomic.Value[types.URL]

		status     U.AtomicValue[health.Status]
		status     atomic.Value[health.Status]
		lastResult *health.HealthCheckResult
		lastSeen   time.Time

@ -59,10 +59,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
}

// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
func (mon *monitor) Start(parent task.Parent) E.Error {
	mon.service = routeSubtask.Parent().Name()
	mon.task = routeSubtask

	if mon.config.Interval <= 0 {
		return E.From(ErrNegativeInterval)
	}
@ -71,6 +68,9 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
		mon.metric = metrics.GetServiceMetrics().HealthStatus.With(metrics.HealthMetricLabels(mon.service))
	}

	mon.service = parent.Name()
	mon.task = parent.Subtask("health_monitor")

	go func() {
		logger := logging.With().Str("name", mon.service).Logger()

@ -78,10 +78,10 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
		if mon.status.Load() != health.StatusError {
			mon.status.Store(health.StatusUnknown)
		}
		mon.task.Finish(nil)
		if mon.metric != nil {
			mon.metric.Reset()
		}
		mon.task.Finish(nil)
	}()

	if err := mon.checkUpdateHealth(); err != nil {
@ -108,6 +108,11 @@ func (mon *monitor) Start(routeSubtask *task.Task) E.Error {
	return nil
}

// Task implements task.TaskStarter.
func (mon *monitor) Task() *task.Task {
	return mon.task
}

// Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason any) {
	mon.task.Finish(reason)
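With this commit, components receive a task.Parent in Start and derive their own subtask, instead of being handed a pre-built one. The sketch below mirrors the monitor's shape for a made-up component; it assumes task.Parent exposes Name() and Subtask(name) as used in this diff, and that E is the project's error package (import path assumed, not confirmed by this diff).

// Sketch only: the "pinger" component and its package are hypothetical.
package pinger

import (
	E "github.com/yusing/go-proxy/internal/error" // assumed import path for E.Error
	"github.com/yusing/go-proxy/internal/task"
)

type Pinger struct {
	service string
	task    *task.Task
}

// Start implements task.TaskStarter: derive a subtask from the parent
// instead of receiving a pre-built one.
func (p *Pinger) Start(parent task.Parent) E.Error {
	p.service = parent.Name()
	p.task = parent.Subtask("pinger")

	go func() {
		defer p.task.Finish(nil)
		<-p.task.Context().Done() // stop when the parent (and thus this task) is canceled
	}()
	return nil
}

// Task implements task.TaskStarter.
func (p *Pinger) Task() *task.Task {
	return p.task
}

// Finish implements task.TaskFinisher.
func (p *Pinger) Finish(reason any) {
	p.task.Finish(reason)
}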