diff --git a/.gitignore b/.gitignore index f89bbdf..fc5d04b 100755 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ todo.md .aider* mtrace.json .env +test.Dockerfile diff --git a/Makefile b/Makefile index f37c7eb..6aa0a9e 100755 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ repush: git push gitlab dev --force rapid-crash: - sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\ + sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\ sleep 3 &&\ sudo docker rm -f test_crash @@ -64,4 +64,4 @@ debug-list-containers: ci-test: mkdir -p /tmp/artifacts - act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" \ No newline at end of file + act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" diff --git a/cmd/main.go b/cmd/main.go index 79b9a92..8da9f36 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -84,7 +84,7 @@ func main() { middleware.LoadComposeFiles() var cfg *config.Config - var err E.NestedError + var err E.Error if cfg, err = config.Load(); err != nil { logrus.Warn(err) } diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index c92d3ef..084b4f1 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -39,7 +39,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { return } - var validateErr E.NestedError + var validateErr E.Error if filename == common.ConfigFileName { validateErr = config.Validate(content) } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { diff --git a/internal/api/v1/query/query.go b/internal/api/v1/query/query.go index 88351f8..cd30cc7 100644 --- a/internal/api/v1/query/query.go +++ b/internal/api/v1/query/query.go @@ -13,7 +13,7 @@ import ( "github.com/yusing/go-proxy/internal/net/http/middleware" ) -func ReloadServer() E.NestedError { +func ReloadServer() E.Error { resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil) if err != nil { return E.From(err) @@ -34,7 +34,7 @@ func ReloadServer() E.NestedError { return nil } -func List[T any](what string) (_ T, outErr E.NestedError) { +func List[T any](what string) (_ T, outErr E.Error) { resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what)) if err != nil { outErr = E.From(err) @@ -54,14 +54,14 @@ func List[T any](what string) (_ T, outErr E.NestedError) { return res, nil } -func ListRoutes() (map[string]map[string]any, E.NestedError) { +func ListRoutes() (map[string]map[string]any, E.Error) { return List[map[string]map[string]any](v1.ListRoutes) } -func ListMiddlewareTraces() (middleware.Traces, E.NestedError) { +func ListMiddlewareTraces() (middleware.Traces, E.Error) { return List[middleware.Traces](v1.ListMiddlewareTraces) } -func DebugListTasks() (map[string]any, E.NestedError) { +func DebugListTasks() (map[string]any, E.Error) { return List[map[string]any](v1.ListTasks) } diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 2f0740d..351e479 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -27,7 +27,7 @@ func NewConfig(cfg *types.AutoCertConfig) *Config { return (*Config)(cfg) } -func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) { +func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { b := E.NewBuilder("unable to initialize autocert") defer b.To(&res) diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 243bbb7..8ab2aad 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -29,7 +29,7 @@ type ( tlsCert *tls.Certificate certExpiries CertExpiries } - ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError) + ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error) CertExpiries map[string]time.Time ) @@ -57,7 +57,7 @@ func (p *Provider) GetExpiries() CertExpiries { return p.certExpiries } -func (p *Provider) ObtainCert() (res E.NestedError) { +func (p *Provider) ObtainCert() (res E.Error) { b := E.NewBuilder("failed to obtain certificate") defer b.To(&res) @@ -112,7 +112,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) { return nil } -func (p *Provider) LoadCert() E.NestedError { +func (p *Provider) LoadCert() E.Error { cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) if err.HasError() { return err @@ -158,7 +158,7 @@ func (p *Provider) ScheduleRenewal() { }() } -func (p *Provider) initClient() E.NestedError { +func (p *Provider) initClient() E.Error { legoClient, err := E.Check(lego.NewClient(p.legoCfg)) if err.HasError() { return E.FailWith("create lego client", err) @@ -178,7 +178,7 @@ func (p *Provider) initClient() E.NestedError { return nil } -func (p *Provider) registerACME() E.NestedError { +func (p *Provider) registerACME() E.Error { if p.user.Registration != nil { return nil } @@ -191,7 +191,7 @@ func (p *Provider) registerACME() E.NestedError { return nil } -func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { +func (p *Provider) saveCert(cert *certificate.Resource) E.Error { /* This should have been done in setup but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) @@ -239,7 +239,7 @@ func (p *Provider) certState() CertState { return CertStateValid } -func (p *Provider) renewIfNeeded() E.NestedError { +func (p *Provider) renewIfNeeded() E.Error { if p.cfg.Provider == ProviderLocal { return nil } @@ -259,7 +259,7 @@ func (p *Provider) renewIfNeeded() E.NestedError { return nil } -func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { +func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) { r := make(CertExpiries, len(cert.Certificate)) for _, cert := range cert.Certificate { x509Cert, err := E.Check(x509.ParseCertificate(cert)) @@ -281,7 +281,7 @@ func providerGenerator[CT any, PT challenge.Provider]( defaultCfg func() *CT, newProvider func(*CT) (PT, error), ) ProviderGenerator { - return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) { + return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { cfg := defaultCfg() err := U.Deserialize(opt, cfg) if err.HasError() { diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 95d6089..62640b1 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -6,7 +6,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -func (p *Provider) Setup() (err E.NestedError) { +func (p *Provider) Setup() (err E.Error) { if err = p.LoadCert(); err != nil { if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist return err diff --git a/internal/common/constants.go b/internal/common/constants.go index 44c9102..e9ab187 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -47,3 +47,5 @@ const ( StopTimeoutDefault = "10s" StopMethodDefault = "stop" ) + +const HeaderCheckRedirect = "X-Goproxy-Check-Redirect" diff --git a/internal/config/config.go b/internal/config/config.go index 66645c6..0820c95 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,7 +55,7 @@ func newConfig() *Config { } } -func Load() (*Config, E.NestedError) { +func Load() (*Config, E.Error) { if instance != nil { return instance, nil } @@ -64,7 +64,7 @@ func Load() (*Config, E.NestedError) { return instance, instance.load() } -func Validate(data []byte) E.NestedError { +func Validate(data []byte) E.Error { return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data) } @@ -78,7 +78,7 @@ func WatchChanges() { task, configEventFlushInterval, OnConfigChange, - func(err E.NestedError) { + func(err E.Error) { logger.Error(err) }, ) @@ -104,7 +104,7 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) { } } -func Reload() E.NestedError { +func Reload() E.Error { // avoid race between config change and API reload request reloadMu.Lock() defer reloadMu.Unlock() @@ -139,7 +139,7 @@ func (cfg *Config) Task() task.Task { func (cfg *Config) StartProxyProviders() { b := E.NewBuilder("errors starting providers") cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { - b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName()))) + b.Add(p.Start(cfg.task.Subtask(p.String()))) }) if b.HasError() { @@ -147,7 +147,7 @@ func (cfg *Config) StartProxyProviders() { } } -func (cfg *Config) load() (res E.NestedError) { +func (cfg *Config) load() (res E.Error) { b := E.NewBuilder("errors loading config") defer b.To(&res) @@ -182,7 +182,7 @@ func (cfg *Config) load() (res E.NestedError) { return } -func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) { +func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { if cfg.autocertProvider != nil { return } @@ -197,7 +197,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested return } -func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) { +func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.Error) { subtask := cfg.task.Subtask("load providers") defer subtask.Finish("done") diff --git a/internal/docker/client.go b/internal/docker/client.go index 570f41a..d4cb5f4 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -37,7 +37,7 @@ var ( ) func init() { - task.GlobalTask("close docker clients").OnComplete("", func() { + task.GlobalTask("close docker clients").OnFinished("", func() { clientMap.RangeAllParallel(func(_ string, c Client) { if c.Connected() { c.Client.Close() @@ -70,7 +70,7 @@ func (c *SharedClient) Close() error { // Returns: // - Client: the Docker client connection. // - error: an error if the connection failed. -func ConnectClient(host string) (Client, E.NestedError) { +func ConnectClient(host string) (Client, E.Error) { clientMapMu.Lock() defer clientMapMu.Unlock() diff --git a/internal/docker/idlewatcher/loading_page.go b/internal/docker/idlewatcher/loading_page.go index 1035bc9..d545c9c 100644 --- a/internal/docker/idlewatcher/loading_page.go +++ b/internal/docker/idlewatcher/loading_page.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" "text/template" + + "github.com/yusing/go-proxy/internal/common" ) type templateData struct { @@ -18,17 +20,15 @@ type templateData struct { var loadingPage []byte var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage))) -const headerCheckRedirect = "X-Goproxy-Check-Redirect" - func (w *Watcher) makeLoadingPageBody() []byte { msg := fmt.Sprintf("%s is starting...", w.ContainerName) data := new(templateData) - data.CheckRedirectHeader = headerCheckRedirect + data.CheckRedirectHeader = common.HeaderCheckRedirect data.Title = w.ContainerName data.Message = strings.ReplaceAll(msg, " ", " ") - buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect))) + buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect))) err := loadingPageTmpl.Execute(buf, data) if err != nil { // should never happen in production panic(err) diff --git a/internal/docker/idlewatcher/config/config.go b/internal/docker/idlewatcher/types/config.go similarity index 91% rename from internal/docker/idlewatcher/config/config.go rename to internal/docker/idlewatcher/types/config.go index 2ecaa18..a12829f 100644 --- a/internal/docker/idlewatcher/config/config.go +++ b/internal/docker/idlewatcher/types/config.go @@ -1,4 +1,4 @@ -package idlewatcher +package types import ( "time" @@ -30,7 +30,7 @@ const ( StopMethodKill StopMethod = "kill" ) -func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) { +func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { if cont == nil { return nil, nil } @@ -80,7 +80,7 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) { }, nil } -func validateDurationPostitive(value string) (time.Duration, E.NestedError) { +func validateDurationPostitive(value string) (time.Duration, E.Error) { d, err := time.ParseDuration(value) if err != nil { return 0, E.Invalid("duration", value).With(err) @@ -91,7 +91,7 @@ func validateDurationPostitive(value string) (time.Duration, E.NestedError) { return d, nil } -func validateSignal(s string) (Signal, E.NestedError) { +func validateSignal(s string) (Signal, E.Error) { switch s { case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", "INT", "TERM", "HUP", "QUIT": @@ -101,7 +101,7 @@ func validateSignal(s string) (Signal, E.NestedError) { return "", E.Invalid("signal", s) } -func validateStopMethod(s string) (StopMethod, E.NestedError) { +func validateStopMethod(s string) (StopMethod, E.Error) { sm := StopMethod(s) switch sm { case StopMethodPause, StopMethodStop, StopMethodKill: diff --git a/internal/docker/idlewatcher/types/waker.go b/internal/docker/idlewatcher/types/waker.go new file mode 100644 index 0000000..914e30e --- /dev/null +++ b/internal/docker/idlewatcher/types/waker.go @@ -0,0 +1,14 @@ +package types + +import ( + "net/http" + + net "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/watcher/health" +) + +type Waker interface { + health.HealthMonitor + http.Handler + net.Stream +} diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index ddb85c7..23630a6 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -1,10 +1,10 @@ package idlewatcher import ( - "net/http" "sync/atomic" "time" + . "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" net "github.com/yusing/go-proxy/internal/net/types" @@ -14,12 +14,6 @@ import ( "github.com/yusing/go-proxy/internal/watcher/health" ) -type Waker interface { - health.HealthMonitor - http.Handler - net.Stream -} - type waker struct { _ U.NoCopy @@ -37,7 +31,7 @@ const ( // TODO: support stream -func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) { +func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { hcCfg := entry.HealthCheckConfig() hcCfg.Timeout = idleWakerCheckTimeout @@ -62,24 +56,26 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr } // lifetime should follow route provider -func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) { +func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { return newWaker(providerSubTask, entry, rp, nil) } -func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) { +func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) { return newWaker(providerSubTask, entry, nil, stream) } // Start implements health.HealthMonitor. -func (w *Watcher) Start(routeSubTask task.Task) E.NestedError { - w.task.OnComplete("stop route", func() { - routeSubTask.Parent().Finish("watcher stopped") +func (w *Watcher) Start(routeSubTask task.Task) E.Error { + routeSubTask.Finish("ignored") + w.task.OnCancel("stop route", func() { + routeSubTask.Parent().Finish(w.task.FinishCause()) }) return nil } // Finish implements health.HealthMonitor. -func (w *Watcher) Finish(reason string) {} +func (w *Watcher) Finish(reason any) { +} // Name implements health.HealthMonitor. func (w *Watcher) Name() string { @@ -109,6 +105,7 @@ func (w *Watcher) Status() health.Status { healthy, _, err := w.hc.CheckHealth() switch { case err != nil: + w.ready.Store(false) return health.StatusError case healthy: w.ready.Store(true) diff --git a/internal/docker/idlewatcher/waker_http.go b/internal/docker/idlewatcher/waker_http.go index 3333280..b7cf6c2 100644 --- a/internal/docker/idlewatcher/waker_http.go +++ b/internal/docker/idlewatcher/waker_http.go @@ -8,6 +8,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/watcher/health" @@ -37,7 +38,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN accept := gphttp.GetAccept(r.Header) acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) - isCheckRedirect := r.Header.Get(headerCheckRedirect) != "" + isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != "" if !isCheckRedirect && acceptHTML { // Send a loading response to the client body := w.makeLoadingPageBody() @@ -56,14 +57,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) defer cancel() - checkCancelled := func() bool { + checkCanceled := func() bool { select { case <-w.task.Context().Done(): - w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context())) + w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context())) http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) return true case <-ctx.Done(): - w.l.Debugf("wake cancelled: %s", context.Cause(ctx)) + w.l.Debugf("wake canceled: %s", context.Cause(ctx)) http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) return true default: @@ -71,7 +72,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN } } - if checkCancelled() { + if checkCanceled() { return false } @@ -84,14 +85,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN } for { - if checkCancelled() { + if checkCanceled() { return false } if w.Status() == health.StatusHealthy { w.resetIdleTimer() if isCheckRedirect { - logrus.Debugf("container %s is ready, redirecting...", w.String()) + logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL()) rw.WriteHeader(http.StatusOK) return } diff --git a/internal/docker/idlewatcher/waker_stream.go b/internal/docker/idlewatcher/waker_stream.go index 326ebeb..1ec4174 100644 --- a/internal/docker/idlewatcher/waker_stream.go +++ b/internal/docker/idlewatcher/waker_stream.go @@ -8,44 +8,47 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/watcher/health" ) // Setup implements types.Stream. +func (w *Watcher) Addr() net.Addr { + return w.stream.Addr() +} + func (w *Watcher) Setup() error { return w.stream.Setup() } // Accept implements types.Stream. -func (w *Watcher) Accept() (conn types.StreamConn, err error) { +func (w *Watcher) Accept() (conn net.Conn, err error) { conn, err = w.stream.Accept() - // timeout means no connection is accepted - var nErr *net.OpError - ok := errors.As(err, &nErr) - if ok && nErr.Timeout() { + if err != nil { + logrus.Errorf("accept failed with error: %s", err) return } if err := w.wakeFromStream(); err != nil { - return nil, err + w.l.Error(err) } - return w.stream.Accept() -} - -// CloseListeners implements types.Stream. -func (w *Watcher) CloseListeners() { - w.stream.CloseListeners() + return } // Handle implements types.Stream. -func (w *Watcher) Handle(conn types.StreamConn) error { +func (w *Watcher) Handle(conn net.Conn) error { if err := w.wakeFromStream(); err != nil { return err } return w.stream.Handle(conn) } +// Close implements types.Stream. +func (w *Watcher) Close() error { + return w.stream.Close() +} + func (w *Watcher) wakeFromStream() error { + w.resetIdleTimer() + // pass through if container is already ready if w.ready.Load() { return nil @@ -66,11 +69,11 @@ func (w *Watcher) wakeFromStream() error { select { case <-w.task.Context().Done(): cause := w.task.FinishCause() - w.l.Debugf("wake cancelled: %s", cause) + w.l.Debugf("wake canceled: %s", cause) return cause case <-ctx.Done(): cause := context.Cause(ctx) - w.l.Debugf("wake cancelled: %s", cause) + w.l.Debugf("wake canceled: %s", cause) return cause default: } diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index 1997d5a..e5388fc 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -10,7 +10,7 @@ import ( "github.com/docker/docker/api/types/container" "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/internal/docker" - idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/task" @@ -49,7 +49,7 @@ var ( const dockerReqTimeout = 3 * time.Second -func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) { +func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) { failure := E.Failure("idle_watcher register") cfg := entry.IdlewatcherConfig() @@ -66,6 +66,7 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) w.Config = cfg w.waker = waker w.resetIdleTimer() + providerSubtask.Finish("used existing watcher") return w, nil } @@ -88,13 +89,11 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) go func() { cause := w.watchUntilDestroy() - watcherMapMu.Lock() watcherMap.Delete(w.ContainerID) - watcherMapMu.Unlock() w.ticker.Stop() w.client.Close() - w.task.Finish(cause.Error()) + w.task.Finish(cause) }() return w, nil @@ -146,7 +145,7 @@ func (w *Watcher) wakeIfStopped() error { return err } - ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) + ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout) defer cancel() // !Hard coded here since theres no constants from Docker API @@ -175,7 +174,7 @@ func (w *Watcher) getStopCallback() StopCallback { panic("should not reach here") } return func() error { - ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) + ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second) defer cancel() return cb(ctx) } @@ -186,8 +185,8 @@ func (w *Watcher) resetIdleTimer() { w.ticker.Reset(w.IdleTimeout) } -func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) { - eventTask = w.task.Subtask("watcher for %s", w.ContainerID) +func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { + eventTask = w.task.Subtask("docker event watcher") eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, @@ -218,13 +217,12 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas func (w *Watcher) watchUntilDestroy() error { dockerWatcher := W.NewDockerWatcherWithClient(w.client) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) + defer eventTask.Finish("stopped") for { select { case <-w.task.Context().Done(): - cause := context.Cause(w.task.Context()) - w.l.Debugf("watcher stopped by context done: %s", cause) - return cause + return w.task.FinishCause() case err := <-dockerEventErrCh: if err != nil && err.IsNot(context.Canceled) { w.l.Error(E.FailWith("docker watcher", err)) diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index 9d8d854..7220dd8 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -8,7 +8,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) { +func Inspect(dockerHost string, containerID string) (*Container, E.Error) { client, err := ConnectClient(dockerHost) defer client.Close() @@ -19,7 +19,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) return client.Inspect(containerID) } -func (c Client) Inspect(containerID string) (*Container, E.NestedError) { +func (c Client) Inspect(containerID string) (*Container, E.Error) { ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) defer cancel() diff --git a/internal/docker/label.go b/internal/docker/label.go index ff454d0..0d3e8e4 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -39,7 +39,7 @@ func (l *Label) String() string { // // Returns: // - error: an error if the field does not exist. -func ApplyLabel[T any](obj *T, l *Label) E.NestedError { +func ApplyLabel[T any](obj *T, l *Label) E.Error { if obj == nil { return E.Invalid("nil object", l) } @@ -81,7 +81,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError { } } -func ParseLabel(label string, value string) (*Label, E.NestedError) { +func ParseLabel(label string, value string) (*Label, E.Error) { parts := strings.Split(label, ".") if len(parts) < 2 { diff --git a/internal/docker/client_info.go b/internal/docker/list_containers.go similarity index 64% rename from internal/docker/client_info.go rename to internal/docker/list_containers.go index 751489e..285de52 100644 --- a/internal/docker/client_info.go +++ b/internal/docker/list_containers.go @@ -11,11 +11,6 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -type ClientInfo struct { - Client Client - Containers []types.Container -} - var listOptions = container.ListOptions{ // created|restarting|running|removing|paused|exited|dead // Filters: filters.NewArgs( @@ -28,28 +23,21 @@ var listOptions = container.ListOptions{ All: true, } -func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) { +func ListContainers(clientHost string) ([]types.Container, E.Error) { dockerClient, err := ConnectClient(clientHost) if err.HasError() { return nil, E.FailWith("connect to docker", err) } defer dockerClient.Close() - ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout")) + ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout")) defer cancel() - var containers []types.Container - if getContainer { - containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions)) - if err.HasError() { - return nil, E.FailWith("list containers", err) - } + containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions)) + if err.HasError() { + return nil, E.FailWith("list containers", err) } - - return &ClientInfo{ - Client: dockerClient, - Containers: containers, - }, nil + return containers, nil } func IsErrConnectionFailed(err error) bool { diff --git a/internal/error/builder.go b/internal/error/builder.go index c7d7ea1..cdab20a 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -12,7 +12,7 @@ type Builder struct { type builder struct { message string - errors []NestedError + errors []Error sync.Mutex } @@ -28,7 +28,7 @@ func NewBuilder(format string, args ...any) Builder { // adding nil is no-op, // // flatten is a boolean flag to flatten the NestedError. -func (b Builder) Add(err NestedError, flatten ...bool) { +func (b Builder) Add(err Error, flatten ...bool) { if err != nil { b.Lock() if len(flatten) > 0 && flatten[0] { @@ -54,7 +54,7 @@ func (b Builder) Addf(format string, args ...any) { } } -func (b Builder) AddRange(errs ...NestedError) { +func (b Builder) AddRange(errs ...Error) { b.Lock() defer b.Unlock() for _, err := range errs { @@ -77,14 +77,14 @@ func (b Builder) AddRangeE(errs ...error) { // // Returns: // - NestedError: the built NestedError. -func (b Builder) Build() NestedError { +func (b Builder) Build() Error { if len(b.errors) == 0 { return nil } return Join(b.message, b.errors...) } -func (b Builder) To(ptr *NestedError) { +func (b Builder) To(ptr *Error) { switch { case ptr == nil: return diff --git a/internal/error/builder_test.go b/internal/error/builder_test.go index 0f3c613..2c24948 100644 --- a/internal/error/builder_test.go +++ b/internal/error/builder_test.go @@ -16,7 +16,7 @@ func TestBuilderEmpty(t *testing.T) { func TestBuilderAddNil(t *testing.T) { eb := NewBuilder("asdf") - var err NestedError + var err Error for range 3 { eb.Add(nil) } @@ -53,7 +53,7 @@ func TestBuilderTo(t *testing.T) { eb := NewBuilder("error occurred") eb.Addf("abcd") - var err NestedError + var err Error eb.To(&err) got := err.String() expected := (`error occurred: diff --git a/internal/error/error.go b/internal/error/error.go index cd95dd8..f4baa27 100644 --- a/internal/error/error.go +++ b/internal/error/error.go @@ -8,35 +8,35 @@ import ( ) type ( - NestedError = *NestedErrorImpl - NestedErrorImpl struct { + Error = *ErrorImpl + ErrorImpl struct { subject string err error - extras []NestedErrorImpl + extras []ErrorImpl } - JSONNestedError struct { - Subject string `json:"subject"` - Err string `json:"error"` - Extras []JSONNestedError `json:"extras,omitempty"` + ErrorJSONMarshaller struct { + Subject string `json:"subject"` + Err string `json:"error"` + Extras []ErrorJSONMarshaller `json:"extras,omitempty"` } ) -func From(err error) NestedError { +func From(err error) Error { if IsNil(err) { return nil } - return &NestedErrorImpl{err: err} + return &ErrorImpl{err: err} } -func FromJSON(data []byte) (NestedError, bool) { - var j JSONNestedError +func FromJSON(data []byte) (Error, bool) { + var j ErrorJSONMarshaller if err := json.Unmarshal(data, &j); err != nil { return nil, false } if j.Err == "" { return nil, false } - extras := make([]NestedErrorImpl, len(j.Extras)) + extras := make([]ErrorImpl, len(j.Extras)) for i, e := range j.Extras { extra, ok := fromJSONObject(e) if !ok { @@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) { } extras[i] = *extra } - return &NestedErrorImpl{ + return &ErrorImpl{ subject: j.Subject, err: errors.New(j.Err), extras: extras, @@ -53,12 +53,12 @@ func FromJSON(data []byte) (NestedError, bool) { // Check is a helper function that // convert (T, error) to (T, NestedError). -func Check[T any](obj T, err error) (T, NestedError) { +func Check[T any](obj T, err error) (T, Error) { return obj, From(err) } -func Join(message string, err ...NestedError) NestedError { - extras := make([]NestedErrorImpl, len(err)) +func Join(message string, err ...Error) Error { + extras := make([]ErrorImpl, len(err)) nErr := 0 for i, e := range err { if e == nil { @@ -70,13 +70,13 @@ func Join(message string, err ...NestedError) NestedError { if nErr == 0 { return nil } - return &NestedErrorImpl{ + return &ErrorImpl{ err: errors.New(message), extras: extras, } } -func JoinE(message string, err ...error) NestedError { +func JoinE(message string, err ...error) Error { b := NewBuilder("%s", message) for _, e := range err { b.AddE(e) @@ -92,13 +92,13 @@ func IsNotNil(err error) bool { return err != nil } -func (ne NestedError) String() string { +func (ne Error) String() string { var buf strings.Builder ne.writeToSB(&buf, 0, "") return buf.String() } -func (ne NestedError) Is(err error) bool { +func (ne Error) Is(err error) bool { if ne == nil { return err == nil } @@ -114,18 +114,18 @@ func (ne NestedError) Is(err error) bool { return false } -func (ne NestedError) IsNot(err error) bool { +func (ne Error) IsNot(err error) bool { return !ne.Is(err) } -func (ne NestedError) Error() error { +func (ne Error) Error() error { if ne == nil { return nil } return ne.buildError(0, "") } -func (ne NestedError) With(s any) NestedError { +func (ne Error) With(s any) Error { if ne == nil { return ne } @@ -133,7 +133,7 @@ func (ne NestedError) With(s any) NestedError { switch ss := s.(type) { case nil: return ne - case *NestedErrorImpl: + case *ErrorImpl: if len(ss.extras) == 1 { ne.extras = append(ne.extras, ss.extras[0]) return ne @@ -151,11 +151,11 @@ func (ne NestedError) With(s any) NestedError { return ne.withError(From(errors.New(msg))) } -func (ne NestedError) Extraf(format string, args ...any) NestedError { +func (ne Error) Extraf(format string, args ...any) Error { return ne.With(errorf(format, args...)) } -func (ne NestedError) Subject(s any, sep ...string) NestedError { +func (ne Error) Subject(s any, sep ...string) Error { if ne == nil { return ne } @@ -179,26 +179,26 @@ func (ne NestedError) Subject(s any, sep ...string) NestedError { return ne } -func (ne NestedError) Subjectf(format string, args ...any) NestedError { +func (ne Error) Subjectf(format string, args ...any) Error { if ne == nil { return ne } return ne.Subject(fmt.Sprintf(format, args...)) } -func (ne NestedError) JSONObject() JSONNestedError { - extras := make([]JSONNestedError, len(ne.extras)) +func (ne Error) JSONObject() ErrorJSONMarshaller { + extras := make([]ErrorJSONMarshaller, len(ne.extras)) for i, e := range ne.extras { extras[i] = e.JSONObject() } - return JSONNestedError{ + return ErrorJSONMarshaller{ Subject: ne.subject, Err: ne.err.Error(), Extras: extras, } } -func (ne NestedError) JSON() []byte { +func (ne Error) JSON() []byte { b, err := json.MarshalIndent(ne.JSONObject(), "", " ") if err != nil { panic(err) @@ -206,19 +206,19 @@ func (ne NestedError) JSON() []byte { return b } -func (ne NestedError) NoError() bool { +func (ne Error) NoError() bool { return ne == nil } -func (ne NestedError) HasError() bool { +func (ne Error) HasError() bool { return ne != nil } -func errorf(format string, args ...any) NestedError { +func errorf(format string, args ...any) Error { return From(fmt.Errorf(format, args...)) } -func fromJSONObject(obj JSONNestedError) (NestedError, bool) { +func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) { data, err := json.Marshal(obj) if err != nil { return nil, false @@ -226,14 +226,14 @@ func fromJSONObject(obj JSONNestedError) (NestedError, bool) { return FromJSON(data) } -func (ne NestedError) withError(err NestedError) NestedError { +func (ne Error) withError(err Error) Error { if ne != nil && err != nil { ne.extras = append(ne.extras, *err) } return ne } -func (ne NestedError) appendMsg(msg string) NestedError { +func (ne Error) appendMsg(msg string) Error { if ne == nil { return nil } @@ -241,7 +241,7 @@ func (ne NestedError) appendMsg(msg string) NestedError { return ne } -func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { +func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) { for range level { sb.WriteString(" ") } @@ -266,7 +266,7 @@ func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { } } -func (ne NestedError) buildError(level int, prefix string) error { +func (ne Error) buildError(level int, prefix string) error { var res error var sb strings.Builder diff --git a/internal/error/error_test.go b/internal/error/error_test.go index 864d313..5d0b9ea 100644 --- a/internal/error/error_test.go +++ b/internal/error/error_test.go @@ -26,7 +26,7 @@ func TestErrorIs(t *testing.T) { } func TestErrorNestedIs(t *testing.T) { - var err NestedError + var err Error ExpectTrue(t, err.Is(nil)) err = Failure("some reason") @@ -40,7 +40,7 @@ func TestErrorNestedIs(t *testing.T) { } func TestIsNil(t *testing.T) { - var err NestedError + var err Error ExpectTrue(t, err.Is(nil)) ExpectTrue(t, err == nil) ExpectTrue(t, err.NoError()) diff --git a/internal/error/errors.go b/internal/error/errors.go index cdc7d9a..4ae9214 100644 --- a/internal/error/errors.go +++ b/internal/error/errors.go @@ -22,62 +22,62 @@ var ( const fmtSubjectWhat = "%w %v: %q" -func Failure(what string) NestedError { +func Failure(what string) Error { return errorf("%s %w", what, ErrFailure) } -func FailedWhy(what string, why string) NestedError { +func FailedWhy(what string, why string) Error { return Failure(what).With(why) } -func FailWith(what string, err any) NestedError { +func FailWith(what string, err any) Error { return Failure(what).With(err) } -func Invalid(subject, what any) NestedError { +func Invalid(subject, what any) Error { return errorf(fmtSubjectWhat, ErrInvalid, subject, what) } -func Unsupported(subject, what any) NestedError { +func Unsupported(subject, what any) Error { return errorf(fmtSubjectWhat, ErrUnsupported, subject, what) } -func Unexpected(subject, what any) NestedError { +func Unexpected(subject, what any) Error { return errorf(fmtSubjectWhat, ErrUnexpected, subject, what) } -func UnexpectedError(err error) NestedError { +func UnexpectedError(err error) Error { return errorf("%w error: %w", ErrUnexpected, err) } -func NotExist(subject, what any) NestedError { +func NotExist(subject, what any) Error { return errorf("%v %w: %v", subject, ErrNotExists, what) } -func Missing(subject any) NestedError { +func Missing(subject any) Error { return errorf("%w %v", ErrMissing, subject) } -func Duplicated(subject, what any) NestedError { +func Duplicated(subject, what any) Error { return errorf("%w %v: %v", ErrDuplicated, subject, what) } -func OutOfRange(subject any, value any) NestedError { +func OutOfRange(subject any, value any) Error { return errorf("%v %w: %v", subject, ErrOutOfRange, value) } -func TypeError(subject any, from, to reflect.Type) NestedError { +func TypeError(subject any, from, to reflect.Type) Error { return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to) } -func TypeError2(subject any, from, to reflect.Value) NestedError { +func TypeError2(subject any, from, to reflect.Value) Error { return TypeError(subject, from.Type(), to.Type()) } -func TypeMismatch[Expect any](value any) NestedError { +func TypeMismatch[Expect any](value any) Error { return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) } -func PanicRecv(format string, args ...any) NestedError { +func PanicRecv(format string, args ...any) Error { return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...)) } diff --git a/internal/net/http/loadbalancer/dummy_response_writer.go b/internal/net/http/loadbalancer/dummy_response_writer.go new file mode 100644 index 0000000..d6ea9f0 --- /dev/null +++ b/internal/net/http/loadbalancer/dummy_response_writer.go @@ -0,0 +1,15 @@ +package loadbalancer + +import "net/http" + +type DummyResponseWriter struct{} + +func (w *DummyResponseWriter) Header() (_ http.Header) { + return +} + +func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) { + return +} + +func (w *DummyResponseWriter) WriteHeader(int) {} diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index 447420f..62e1a51 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -21,7 +21,7 @@ func (lb *LoadBalancer) newIPHash() impl { if len(lb.Options) == 0 { return impl } - var err E.NestedError + var err E.Error impl.realIP, err = middleware.NewRealIP(lb.Options) if err != nil { logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index c7c6cc0..ed860e6 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -1,10 +1,13 @@ package loadbalancer import ( + "context" "net/http" "sync" "time" + "github.com/yusing/go-proxy/internal/common" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/task" @@ -54,10 +57,10 @@ func New(cfg *Config) *LoadBalancer { } // Start implements task.TaskStarter. -func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError { +func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error { lb.startTime = time.Now() lb.task = routeSubtask - lb.task.OnComplete("loadbalancer cleanup", func() { + lb.task.OnFinished("loadbalancer cleanup", func() { if lb.impl != nil { lb.pool.RangeAll(func(k string, v *Server) { lb.impl.OnRemoveServer(v) @@ -69,7 +72,7 @@ func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError { } // Finish implements task.TaskFinisher. -func (lb *LoadBalancer) Finish(reason string) { +func (lb *LoadBalancer) Finish(reason any) { lb.task.Finish(reason) } @@ -128,7 +131,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.rebalance() lb.impl.OnAddServer(srv) - logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) + logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) } func (lb *LoadBalancer) RemoveServer(srv *Server) { @@ -147,11 +150,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { if lb.pool.Size() == 0 { lb.task.Finish("no server left") - logger.Infof("[remove] loadbalancer %s stopped", lb.Link) + logger.Infof("loadbalancer %s stopped", lb.Link) return } - logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) + logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) } func (lb *LoadBalancer) rebalance() { @@ -211,6 +214,21 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) return } + if r.Header.Get(common.HeaderCheckRedirect) != "" { + ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) + defer cancel() + // send dummy request to wake all servers + var dummyRW *DummyResponseWriter + for _, srv := range srvs { + // wake only if server implements Waker + _, ok := srv.handler.(idlewatcher.Waker) + if !ok { + continue + } + wakeReq := r.Clone(ctx) + srv.ServeHTTP(dummyRW, wakeReq) + } + } lb.impl.ServeHTTP(srvs, rw, r) } @@ -261,10 +279,9 @@ func (lb *LoadBalancer) String() string { func (lb *LoadBalancer) availServers() []*Server { avail := make([]*Server, 0, lb.pool.Size()) lb.pool.RangeAll(func(_ string, srv *Server) { - if srv.Status().Bad() { - return + if srv.Status().Good() { + avail = append(avail, srv) } - avail = append(avail, srv) }) return avail } diff --git a/internal/net/http/loadbalancer/round_robin.go b/internal/net/http/loadbalancer/round_robin.go index 557d4e3..41e70c8 100644 --- a/internal/net/http/loadbalancer/round_robin.go +++ b/internal/net/http/loadbalancer/round_robin.go @@ -14,8 +14,8 @@ func (lb *roundRobin) OnAddServer(srv *Server) {} func (lb *roundRobin) OnRemoveServer(srv *Server) {} func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) { - index := lb.index.Add(1) - srvs[index%uint32(len(srvs))].ServeHTTP(rw, r) + index := lb.index.Add(1) % uint32(len(srvs)) + srvs[index].ServeHTTP(rw, r) if lb.index.Load() >= 2*uint32(len(srvs)) { lb.index.Store(0) } diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index 502d6f1..5702851 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts { } } -func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) { +func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { wl := new(cidrWhitelist) wl.m = &Middleware{ impl: wl, diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 7de7326..20e7ff3 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -33,7 +33,7 @@ var CloudflareRealIP = &realIP{ m: &Middleware{withOptions: NewCloudflareRealIP}, } -func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) { +func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) { cri := new(realIP) cri.m = &Middleware{ impl: cri, diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 63781a5..93e51d5 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -36,7 +36,7 @@ var ForwardAuth = &forwardAuth{ m: &Middleware{withOptions: NewForwardAuthfunc}, } -func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { +func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { fa := new(forwardAuth) fa.forwardAuthOpts = new(forwardAuthOpts) err := Deserialize(optsRaw, fa.forwardAuthOpts) diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 502d18a..804450a 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -11,7 +11,7 @@ import ( ) type ( - Error = E.NestedError + Error = E.Error ReverseProxy = gphttp.ReverseProxy ProxyRequest = gphttp.ProxyRequest @@ -24,7 +24,7 @@ type ( BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request) RewriteFunc func(req *Request) ModifyResponseFunc func(resp *Response) error - CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError) + CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error) OptionsRaw = map[string]any Options any @@ -77,7 +77,7 @@ func (m *Middleware) MarshalJSON() ([]byte, error) { }, "", " ") } -func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) { +func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { if len(optsRaw) != 0 && m.withOptions != nil { return m.withOptions(optsRaw) } @@ -108,7 +108,7 @@ func (m *Middleware) ModifyResponse(resp *Response) error { } // TODO: check conflict or duplicates. -func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) { +func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) { middlewares = make([]*Middleware, 0, len(middlewaresMap)) invalidM := E.NewBuilder("invalid middlewares") @@ -136,7 +136,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid return } -func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) { +func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { var middlewares []*Middleware middlewares, err = createMiddlewares(middlewaresMap) if err != nil { diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index d2b760a..19bd92f 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -10,7 +10,7 @@ import ( "gopkg.in/yaml.v3" ) -func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) { +func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) { fileContent, err := os.ReadFile(filePath) if err != nil { return nil, E.FailWith("read middleware compose file", err) @@ -18,7 +18,7 @@ func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E return BuildMiddlewaresFromYAML(fileContent) } -func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) { +func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) { b := E.NewBuilder("middlewares compile errors") defer b.To(&outErr) diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 261d4f0..2b9ca2e 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{ m: &Middleware{withOptions: NewModifyRequest}, } -func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) { +func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyRequest) var mrFunc RewriteFunc if common.IsDebug { diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index dd2ad24..62011d8 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -24,7 +24,7 @@ var ModifyResponse = &modifyResponse{ m: &Middleware{withOptions: NewModifyResponse}, } -func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) { +func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyResponse) mr.m = &Middleware{impl: mr} if common.IsDebug { diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go index 6b352fa..b500bee 100644 --- a/internal/net/http/middleware/oauth2.go +++ b/internal/net/http/middleware/oauth2.go @@ -26,7 +26,7 @@ var OAuth2 = &oAuth2{ m: &Middleware{withOptions: NewAuthentikOAuth2}, } -func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) { +func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { oauth := new(oAuth2) oauth.m = &Middleware{ impl: oauth, diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index f3ebc25..12d674c 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -41,7 +41,7 @@ var realIPOptsDefault = func() *realIPOpts { } } -func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) { +func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) { riWithOpts := new(realIP) riWithOpts.m = &Middleware{ impl: riWithOpts, diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index 4f21c77..b62d1d6 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -72,7 +72,7 @@ type testArgs struct { scheme string } -func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) { +func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { var body io.Reader var rr requestRecorder var proxyURL *url.URL diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 1317757..230ca16 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -9,7 +9,7 @@ import ( type CIDR net.IPNet -func (cidr *CIDR) ConvertFrom(val any) E.NestedError { +func (cidr *CIDR) ConvertFrom(val any) E.Error { cidrStr, ok := val.(string) if !ok { return E.TypeMismatch[string](val) diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go index 6306089..871521f 100644 --- a/internal/net/types/stream.go +++ b/internal/net/types/stream.go @@ -7,13 +7,7 @@ import ( type Stream interface { fmt.Stringer + net.Listener Setup() error - Accept() (conn StreamConn, err error) - Handle(conn StreamConn) error - CloseListeners() -} - -type StreamConn interface { - RemoteAddr() net.Addr - Close() error + Handle(conn net.Conn) error } diff --git a/internal/proxy/entry/entry.go b/internal/proxy/entry/entry.go index 9e80ac3..0c17ae8 100644 --- a/internal/proxy/entry/entry.go +++ b/internal/proxy/entry/entry.go @@ -1,7 +1,7 @@ package entry import ( - idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" net "github.com/yusing/go-proxy/internal/net/types" @@ -18,7 +18,7 @@ type Entry interface { IdlewatcherConfig() *idlewatcher.Config } -func ValidateEntry(m *RawEntry) (Entry, E.NestedError) { +func ValidateEntry(m *RawEntry) (Entry, E.Error) { m.FillMissingFields() scheme, err := T.NewScheme(m.Scheme) diff --git a/internal/proxy/entry/reverse_proxy.go b/internal/proxy/entry/reverse_proxy.go index 95f4352..7d62973 100644 --- a/internal/proxy/entry/reverse_proxy.go +++ b/internal/proxy/entry/reverse_proxy.go @@ -5,7 +5,7 @@ import ( "net/url" "github.com/yusing/go-proxy/internal/docker" - idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" net "github.com/yusing/go-proxy/internal/net/types" diff --git a/internal/proxy/entry/stream.go b/internal/proxy/entry/stream.go index dd74de2..35ac4ef 100644 --- a/internal/proxy/entry/stream.go +++ b/internal/proxy/entry/stream.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/yusing/go-proxy/internal/docker" - idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" + idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" net "github.com/yusing/go-proxy/internal/net/types" diff --git a/internal/proxy/fields/headers.go b/internal/proxy/fields/headers.go index 86a9837..109eb44 100644 --- a/internal/proxy/fields/headers.go +++ b/internal/proxy/fields/headers.go @@ -7,7 +7,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { +func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.Error) { h := make(http.Header) for k, v := range headers { vSplit := strings.Split(v, ",") diff --git a/internal/proxy/fields/host.go b/internal/proxy/fields/host.go index 446c051..68c17c2 100644 --- a/internal/proxy/fields/host.go +++ b/internal/proxy/fields/host.go @@ -9,6 +9,6 @@ type ( Subdomain = Alias ) -func ValidateHost[String ~string](s String) (Host, E.NestedError) { +func ValidateHost[String ~string](s String) (Host, E.Error) { return Host(s), nil } diff --git a/internal/proxy/fields/path_pattern.go b/internal/proxy/fields/path_pattern.go index 0a42ce5..8d9abd3 100644 --- a/internal/proxy/fields/path_pattern.go +++ b/internal/proxy/fields/path_pattern.go @@ -13,7 +13,7 @@ type ( var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) -func ValidatePathPattern(s string) (PathPattern, E.NestedError) { +func ValidatePathPattern(s string) (PathPattern, E.Error) { if len(s) == 0 { return "", E.Invalid("path", "must not be empty") } @@ -23,7 +23,7 @@ func ValidatePathPattern(s string) (PathPattern, E.NestedError) { return PathPattern(s), nil } -func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) { +func ValidatePathPatterns(s []string) (PathPatterns, E.Error) { if len(s) == 0 { return []PathPattern{"/"}, nil } diff --git a/internal/proxy/fields/port.go b/internal/proxy/fields/port.go index 9d517e3..5780005 100644 --- a/internal/proxy/fields/port.go +++ b/internal/proxy/fields/port.go @@ -8,7 +8,7 @@ import ( type Port int -func ValidatePort[String ~string](v String) (Port, E.NestedError) { +func ValidatePort[String ~string](v String) (Port, E.Error) { p, err := strconv.Atoi(string(v)) if err != nil { return ErrPort, E.Invalid("port number", v).With(err) @@ -16,7 +16,7 @@ func ValidatePort[String ~string](v String) (Port, E.NestedError) { return ValidatePortInt(p) } -func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) { +func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) { p := Port(v) if !p.inBound() { return ErrPort, E.OutOfRange("port", p) diff --git a/internal/proxy/fields/scheme.go b/internal/proxy/fields/scheme.go index 457b9cd..2e4f6e5 100644 --- a/internal/proxy/fields/scheme.go +++ b/internal/proxy/fields/scheme.go @@ -6,7 +6,7 @@ import ( type Scheme string -func NewScheme[String ~string](s String) (Scheme, E.NestedError) { +func NewScheme[String ~string](s String) (Scheme, E.Error) { switch s { case "http", "https", "tcp", "udp": return Scheme(s), nil diff --git a/internal/proxy/fields/stream_port.go b/internal/proxy/fields/stream_port.go index 75824f8..020d455 100644 --- a/internal/proxy/fields/stream_port.go +++ b/internal/proxy/fields/stream_port.go @@ -12,7 +12,7 @@ type StreamPort struct { ProxyPort Port `json:"proxy"` } -func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) { +func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { split := strings.Split(p, ":") switch len(split) { @@ -47,7 +47,7 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) { return StreamPort{listeningPort, proxyPort}, nil } -func parseNameToPort(name string) (Port, E.NestedError) { +func parseNameToPort(name string) (Port, E.Error) { port, ok := common.ServiceNamePortMapTCP[name] if !ok { return ErrPort, E.Invalid("service", name) diff --git a/internal/proxy/fields/stream_scheme.go b/internal/proxy/fields/stream_scheme.go index 17835db..d195a29 100644 --- a/internal/proxy/fields/stream_scheme.go +++ b/internal/proxy/fields/stream_scheme.go @@ -12,7 +12,7 @@ type StreamScheme struct { ProxyScheme Scheme `json:"proxy"` } -func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { +func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) { ss = &StreamScheme{} parts := strings.Split(s, ":") if len(parts) == 1 { diff --git a/internal/route/http.go b/internal/route/http.go index da3a964..0010e3c 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) { } } -func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) { +func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) { var trans *http.Transport if entry.NoTLSVerify { @@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string { } // Start implements task.TaskStarter. -func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError { +func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { if entry.ShouldNotServe(r) { providerSubtask.Finish("should not serve") return nil @@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError { r.addToLoadBalancer() } else { httpRoutes.Store(string(r.Alias), r) - r.task.OnComplete("stop rp", func() { + r.task.OnFinished("remove from route table", func() { httpRoutes.Delete(string(r.Alias)) }) } @@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError { } // Finish implements task.TaskFinisher. -func (r *HTTPRoute) Finish(reason string) { +func (r *HTTPRoute) Finish(reason any) { r.task.Finish(reason) } @@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() { } } else { lb = loadbalancer.New(r.LoadBalance) - lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link) - lbTask.OnComplete("remove lb from routes", func() { + lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link) + lbTask.OnCancel("remove lb from routes", func() { httpRoutes.Delete(r.LoadBalance.Link) }) lb.Start(lbTask) @@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() { httpRoutes.Store(r.LoadBalance.Link, linked) } r.loadBalancer = lb - r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) + r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) lb.AddServer(r.server) - r.task.OnComplete("remove server from lb", func() { + r.task.OnCancel("remove server from lb", func() { lb.RemoveServer(r.server) }) } diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 2d6afd8..440e47b 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -25,7 +25,7 @@ var ( AliasRefRegexOld = regexp.MustCompile(`\$\d+`) ) -func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) { +func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) { if dockerHost == common.DockerHostFromEnv { dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost) } @@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher { return W.NewDockerWatcher(p.dockerHost) } -func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { +func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) { routes = R.NewRoutes() entries := entry.NewProxyEntries() - info, err := D.GetClientInfo(p.dockerHost, true) + containers, err := D.ListContainers(p.dockerHost) if err != nil { - return routes, E.FailWith("connect to docker", err) + return routes, err } errors := E.NewBuilder("errors in docker labels") - for _, c := range info.Containers { + for _, c := range containers { container := D.FromDocker(&c, p.dockerHost) if container.IsExcluded { continue @@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { }) } - entries.RangeAll(func(_ string, e *entry.RawEntry) { - e.Container.DockerHost = p.dockerHost - }) - routes, err = R.FromEntries(entries) errors.Add(err) @@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool { // Returns a list of proxy entries for a container. // Always non-nil. -func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) { +func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) { entries = entry.NewProxyEntries() if p.shouldIgnore(container) { @@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent return entries, errors.Build().Subject(container.ContainerName) } -func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) { +func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) { b := E.NewBuilder("errors in label %s", key) defer b.To(&res) diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index f10ba43..a7fee91 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -1,7 +1,9 @@ package provider import ( + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher" @@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { return } - oldRoutes.RangeAll(func(k string, v *route.Route) { - if !newRoutes.Has(k) { - handler.Remove(v) + if common.IsDebug { + eventsLog := E.NewBuilder("events") + for _, event := range events { + eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) + } + handler.provider.l.Debug(eventsLog.String()) + oldRoutesLog := E.NewBuilder("old routes") + oldRoutes.RangeAll(func(k string, r *route.Route) { + oldRoutesLog.Addf(k) + }) + handler.provider.l.Debug(oldRoutesLog.String()) + newRoutesLog := E.NewBuilder("new routes") + newRoutes.RangeAll(func(k string, r *route.Route) { + newRoutesLog.Addf(k) + }) + handler.provider.l.Debug(newRoutesLog.String()) + } + + oldRoutes.RangeAll(func(k string, oldr *route.Route) { + newr, ok := newRoutes.Load(k) + if !ok { + handler.Remove(oldr) + } else if handler.matchAny(events, newr) { + handler.Update(parent, oldr, newr) + } else if entry.ShouldNotServe(newr) { + handler.Remove(oldr) } }) newRoutes.RangeAll(func(k string, newr *route.Route) { - if oldRoutes.Has(k) { - for _, ev := range events { - if handler.match(ev, newr) { - old, ok := oldRoutes.Load(k) - if !ok { // should not happen - panic("race condition") - } - handler.Update(parent, old, newr) - return - } - } - } else { + if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) { handler.Add(parent, newr) } }) } +func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool { + for _, event := range events { + if handler.match(event, route) { + return true + } + } + return false +} + func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { - switch handler.provider.t { + switch handler.provider.GetType() { case ProviderTypeDocker: return route.Entry.Container.ContainerID == event.ActorID || route.Entry.Container.ContainerName == event.ActorName @@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool func (handler *EventHandler) Add(parent task.Task, route *route.Route) { err := handler.provider.startRoute(parent, route) if err != nil { - handler.errs.Add(err) + handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err)) } else { handler.added = append(handler.added, route.Entry.Alias) } } func (handler *EventHandler) Remove(route *route.Route) { - route.Finish("route removal") + route.Finish("route removed") + handler.provider.routes.Delete(route.Entry.Alias) handler.removed = append(handler.removed, route.Entry.Alias) } @@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new oldRoute.Finish("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { - handler.errs.Add(err) + handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err)) } else { handler.updated = append(handler.updated, newRoute.Entry.Alias) } diff --git a/internal/route/provider/file.go b/internal/route/provider/file.go index eefe313..472ca9d 100644 --- a/internal/route/provider/file.go +++ b/internal/route/provider/file.go @@ -18,7 +18,7 @@ type FileProvider struct { path string } -func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) { +func FileProviderImpl(filename string) (ProviderImpl, E.Error) { impl := &FileProvider{ fileName: filename, path: path.Join(common.ConfigBasePath, filename), @@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) { } } -func Validate(data []byte) E.NestedError { +func Validate(data []byte) E.Error { return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data) } @@ -42,7 +42,7 @@ func (p FileProvider) String() string { return p.fileName } -func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { +func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) { routes = R.NewRoutes() b := E.NewBuilder("validation failure") diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index ec2f965..195d910 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -7,7 +7,6 @@ import ( "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/proxy/entry" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/task" W "github.com/yusing/go-proxy/internal/watcher" @@ -29,7 +28,7 @@ type ( ProviderImpl interface { fmt.Stringer NewWatcher() W.Watcher - LoadRoutesImpl() (R.Routes, E.NestedError) + LoadRoutesImpl() (R.Routes, E.Error) } ProviderType string ProviderStats struct { @@ -43,7 +42,7 @@ const ( ProviderTypeDocker ProviderType = "docker" ProviderTypeFile ProviderType = "file" - providerEventFlushInterval = 500 * time.Millisecond + providerEventFlushInterval = 300 * time.Millisecond ) func newProvider(name string, t ProviderType) *Provider { @@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider { return p } -func NewFileProvider(filename string) (p *Provider, err E.NestedError) { +func NewFileProvider(filename string) (p *Provider, err E.Error) { name := path.Base(filename) if name == "" { return nil, E.Invalid("file name", "empty") @@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) { return } -func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) { +func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) { if name == "" { return nil, E.Invalid("provider name", "empty") } @@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError { - if entry.UseLoadBalance(r) { - r.Entry.Alias = p.String() + "/" + r.Entry.Alias - } - subtask := parent.Subtask(r.Entry.Alias) +func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { + subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias) err := r.Start(subtask) if err != nil { p.routes.Delete(r.Entry.Alias) - subtask.Finish(err.String()) // just to ensure + subtask.Finish(err) // just to ensure return err } else { - subtask.OnComplete("del from provider", func() { + p.routes.Store(r.Entry.Alias, r) + subtask.OnFinished("del from provider", func() { p.routes.Delete(r.Entry.Alias) }) } @@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError { } // Start implements task.TaskStarter. -func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) { +func (p *Provider) Start(configSubtask task.Task) (res E.Error) { errors := E.NewBuilder("errors starting routes") defer errors.To(&res) @@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) { handler.Log() flushTask.Finish("events flushed") }, - func(err E.NestedError) { + func(err E.Error) { p.l.Error(err) }, ) @@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) { return p.routes.Load(alias) } -func (p *Provider) LoadRoutes() E.NestedError { - var err E.NestedError +func (p *Provider) LoadRoutes() E.Error { + var err E.Error p.routes, err = p.LoadRoutesImpl() if p.routes.Size() > 0 { return err diff --git a/internal/route/raw.go b/internal/route/raw.go new file mode 100644 index 0000000..f206a74 --- /dev/null +++ b/internal/route/raw.go @@ -0,0 +1,94 @@ +package route + +import ( + "errors" + "fmt" + "net" + "time" + + T "github.com/yusing/go-proxy/internal/proxy/fields" + U "github.com/yusing/go-proxy/internal/utils" +) + +type ( + RawStream struct { + *StreamRoute + + listener net.Listener + targetAddr net.Addr + } +) + +const ( + streamBufferSize = 8192 + streamDialTimeout = 5 * time.Second +) + +func NewRawStreamRoute(base *StreamRoute) *RawStream { + return &RawStream{ + StreamRoute: base, + } +} + +func (route *RawStream) Setup() error { + var lcfg net.ListenConfig + var err error + + switch route.Scheme.ListeningScheme { + case "tcp": + route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) + if err != nil { + return err + } + tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) + if err != nil { + return err + } + route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) + route.listener = tcpListener + case "udp": + route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) + if err != nil { + return err + } + udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort)) + if err != nil { + return err + } + route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port) + route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener) + default: + return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme)) + } + + return nil +} + +func (route *RawStream) Accept() (net.Conn, error) { + if route.listener == nil { + return nil, errors.New("listener not yet set up") + } + return route.listener.Accept() +} + +func (route *RawStream) Handle(c net.Conn) error { + clientConn := c.(net.Conn) + + defer clientConn.Close() + route.task.OnCancel("close conn", func() { clientConn.Close() }) + + dialer := &net.Dialer{Timeout: streamDialTimeout} + + serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) + serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr) + if err != nil { + return err + } + + pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) + return pipe.Start() +} + +func (route *RawStream) Close() error { + return route.listener.Close() +} diff --git a/internal/route/route.go b/internal/route/route.go index 6bcf114..93b3a4c 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container { return rt.Entry.Container } -func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) { +func NewRoute(raw *entry.RawEntry) (*Route, E.Error) { en, err := entry.ValidateEntry(raw) if err != nil { return nil, err @@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) { }, nil } -func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) { +func FromEntries(entries entry.RawEntries) (Routes, E.Error) { b := E.NewBuilder("errors in routes") routes := NewRoutes() diff --git a/internal/route/stream.go b/internal/route/stream.go index 1622a06..957a166 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - stdNet "net" "sync" "github.com/sirupsen/logrus" @@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] { return streamRoutes } -func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) { +func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) { // TODO: support non-coherent scheme if !entry.Scheme.IsCoherent() { return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) @@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) { }, nil } -func (r *StreamRoute) Finish(reason string) { - r.task.Finish(reason) -} - func (r *StreamRoute) String() string { return fmt.Sprintf("stream %s", r.Alias) } // Start implements task.TaskStarter. -func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { +func (r *StreamRoute) Start(providerSubtask task.Task) E.Error { if entry.ShouldNotServe(r) { providerSubtask.Finish("should not serve") return nil @@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { r.HealthCheck.Disable = true } - if r.Scheme.ListeningScheme.IsTCP() { - r.Stream = NewTCPRoute(r) - } else { - r.Stream = NewUDPRoute(r) - } + // if r.Scheme.ListeningScheme.IsTCP() { + // r.Stream = NewTCPRoute(r) + // } else { + // r.Stream = NewUDPRoute(r) + // } + r.task = providerSubtask + r.Stream = NewRawStreamRoute(r) r.l = logrus.WithField("route", r.Stream.String()) switch { @@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) if err != nil { + r.task.Finish(err) return err } r.Stream = waker @@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { case entry.UseHealthCheck(r): r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck) } - r.task = providerSubtask - r.task.OnComplete("stop stream", r.CloseListeners) if err := r.Setup(); err != nil { + r.task.Finish(err) return E.FailWith("setup", err) } - r.l.Infof("listening on port %d", r.Port.ListeningPort) - go r.acceptConnections() + r.task.OnFinished("close stream", func() { + if err := r.Close(); err != nil { + r.l.Error("close stream error: ", err) + } + }) + r.task.OnFinished("remove from route table", func() { + streamRoutes.Delete(string(r.Alias)) + }) + + r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort) if r.HealthMon != nil { - r.HealthMon.Start(r.task.Subtask("health monitor")) + if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { + logrus.Warn("health monitor error: ", err) + } } + + go r.acceptConnections() streamRoutes.Store(string(r.Alias), r) return nil } +func (r *StreamRoute) Finish(reason any) { + r.task.Finish(reason) +} + func (r *StreamRoute) acceptConnections() { + defer r.task.Finish("listener closed") + for { select { case <-r.task.Context().Done(): @@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() { if err != nil { select { case <-r.task.Context().Done(): - return default: - var nErr *stdNet.OpError - ok := errors.As(err, &nErr) - if !(ok && nErr.Timeout()) { - r.l.Error("accept connection error: ", err) - r.task.Finish(err.Error()) - return - } - continue + r.l.Error("accept connection error: ", err) + r.task.Finish(err) } + return } - connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr())) go func() { err := r.Handle(conn) if err != nil && !errors.Is(err, context.Canceled) { r.l.Error(err) - connTask.Finish(err.Error()) } else { connTask.Finish("connection closed") } diff --git a/internal/route/tcp.go b/internal/route/tcp.go index 20d378b..d14b482 100755 --- a/internal/route/tcp.go +++ b/internal/route/tcp.go @@ -1,71 +1,68 @@ package route -import ( - "context" - "fmt" - "net" - "time" +// import ( +// "context" +// "fmt" +// "net" +// "time" - "github.com/yusing/go-proxy/internal/net/types" - T "github.com/yusing/go-proxy/internal/proxy/fields" - U "github.com/yusing/go-proxy/internal/utils" - F "github.com/yusing/go-proxy/internal/utils/functional" -) +// "github.com/yusing/go-proxy/internal/net/types" +// T "github.com/yusing/go-proxy/internal/proxy/fields" +// U "github.com/yusing/go-proxy/internal/utils" +// F "github.com/yusing/go-proxy/internal/utils/functional" +// ) -const tcpDialTimeout = 5 * time.Second +// const tcpDialTimeout = 5 * time.Second -type ( - TCPConnMap = F.Map[net.Conn, struct{}] - TCPRoute struct { - *StreamRoute - listener *net.TCPListener - } -) +// type ( +// TCPConnMap = F.Map[net.Conn, struct{}] +// TCPRoute struct { +// *StreamRoute +// listener *net.TCPListener +// } +// ) -func NewTCPRoute(base *StreamRoute) *TCPRoute { - return &TCPRoute{StreamRoute: base} -} +// func NewTCPRoute(base *StreamRoute) *TCPRoute { +// return &TCPRoute{StreamRoute: base} +// } -func (route *TCPRoute) Setup() error { - in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - //! this read the allocated port from original ':0' - route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) - route.listener = in.(*net.TCPListener) - return nil -} +// func (route *TCPRoute) Setup() error { +// var cfg net.ListenConfig +// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) +// if err != nil { +// return err +// } +// //! this read the allocated port from original ':0' +// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) +// route.listener = in.(*net.TCPListener) +// return nil +// } -func (route *TCPRoute) Accept() (types.StreamConn, error) { - route.listener.SetDeadline(time.Now().Add(time.Second)) - return route.listener.Accept() -} +// func (route *TCPRoute) Accept() (types.StreamConn, error) { +// return route.listener.Accept() +// } -func (route *TCPRoute) Handle(c types.StreamConn) error { - clientConn := c.(net.Conn) +// func (route *TCPRoute) Handle(c types.StreamConn) error { +// clientConn := c.(net.Conn) - defer clientConn.Close() - route.task.OnComplete("close conn", func() { clientConn.Close() }) +// defer clientConn.Close() +// route.task.OnCancel("close conn", func() { clientConn.Close() }) - ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) +// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) - serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) - dialer := &net.Dialer{} +// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) +// dialer := &net.Dialer{} - serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) - cancel() - if err != nil { - return err - } +// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) +// cancel() +// if err != nil { +// return err +// } - pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) - return pipe.Start() -} +// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) +// return pipe.Start() +// } -func (route *TCPRoute) CloseListeners() { - if route.listener == nil { - return - } - route.listener.Close() -} +// func (route *TCPRoute) Close() error { +// return route.listener.Close() +// } diff --git a/internal/route/udp.go b/internal/route/udp.go index 2f19f96..8f7ae89 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -1,145 +1,149 @@ package route -import ( - "errors" - "fmt" - "io" - "net" - "time" +// import ( +// "errors" +// "fmt" +// "io" +// "net" - "github.com/yusing/go-proxy/internal/net/types" - T "github.com/yusing/go-proxy/internal/proxy/fields" - U "github.com/yusing/go-proxy/internal/utils" - F "github.com/yusing/go-proxy/internal/utils/functional" -) +// "github.com/yusing/go-proxy/internal/net/types" +// T "github.com/yusing/go-proxy/internal/proxy/fields" +// U "github.com/yusing/go-proxy/internal/utils" +// F "github.com/yusing/go-proxy/internal/utils/functional" +// ) -type ( - UDPRoute struct { - *StreamRoute +// type ( +// UDPRoute struct { +// *StreamRoute - connMap UDPConnMap +// connMap UDPConnMap - listeningConn *net.UDPConn - targetAddr *net.UDPAddr - } - UDPConn struct { - key string - src *net.UDPConn - dst *net.UDPConn - U.BidirectionalPipe - } - UDPConnMap = F.Map[string, *UDPConn] -) +// listeningConn net.PacketConn +// targetAddr *net.UDPAddr +// } +// UDPConn struct { +// key string +// src net.Conn +// dst net.Conn +// U.BidirectionalPipe +// } +// UDPConnMap = F.Map[string, *UDPConn] +// ) -var NewUDPConnMap = F.NewMap[UDPConnMap] +// var NewUDPConnMap = F.NewMap[UDPConnMap] -const udpBufferSize = 8192 +// const udpBufferSize = 8192 -func NewUDPRoute(base *StreamRoute) *UDPRoute { - return &UDPRoute{ - StreamRoute: base, - connMap: NewUDPConnMap(), - } -} +// func NewUDPRoute(base *StreamRoute) *UDPRoute { +// return &UDPRoute{ +// StreamRoute: base, +// connMap: NewUDPConnMap(), +// } +// } -func (route *UDPRoute) Setup() error { - laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr) - if err != nil { - return err - } - raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) - if err != nil { - source.Close() - return err - } +// func (route *UDPRoute) Setup() error { +// var cfg net.ListenConfig +// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort)) +// if err != nil { +// return err +// } +// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) +// if err != nil { +// source.Close() +// return err +// } - //! this read the allocated listeningPort from original ':0' - route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) +// //! this read the allocated listeningPort from original ':0' +// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) - route.listeningConn = source - route.targetAddr = raddr +// route.listeningConn = source +// route.targetAddr = raddr - return nil -} +// return nil +// } -func (route *UDPRoute) Accept() (types.StreamConn, error) { - in := route.listeningConn +// func (route *UDPRoute) Accept() (types.StreamConn, error) { +// in := route.listeningConn - buffer := make([]byte, udpBufferSize) - route.listeningConn.SetReadDeadline(time.Now().Add(time.Second)) - nRead, srcAddr, err := in.ReadFromUDP(buffer) - if err != nil { - return nil, err - } +// buffer := make([]byte, udpBufferSize) +// nRead, srcAddr, err := in.ReadFrom(buffer) +// if err != nil { +// return nil, err +// } - if nRead == 0 { - return nil, io.ErrShortBuffer - } +// if nRead == 0 { +// return nil, io.ErrShortBuffer +// } - key := srcAddr.String() - conn, ok := route.connMap.Load(key) +// key := srcAddr.String() +// conn, ok := route.connMap.Load(key) - if !ok { - srcConn, err := net.DialUDP("udp", nil, srcAddr) - if err != nil { - return nil, err - } - dstConn, err := net.DialUDP("udp", nil, route.targetAddr) - if err != nil { - srcConn.Close() - return nil, err - } - conn = &UDPConn{ - key, - srcConn, - dstConn, - U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), - } - route.connMap.Store(key, conn) - } +// if !ok { +// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String()) +// if err != nil { +// return nil, err +// } +// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String()) +// if err != nil { +// srcConn.Close() +// return nil, err +// } +// conn = &UDPConn{ +// key, +// srcConn, +// dstConn, +// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), +// } +// route.connMap.Store(key, conn) +// } - _, err = conn.dst.Write(buffer[:nRead]) - return conn, err -} +// _, err = conn.dst.Write(buffer[:nRead]) +// return conn, err +// } -func (route *UDPRoute) Handle(c types.StreamConn) error { - conn := c.(*UDPConn) - err := conn.Start() - route.connMap.Delete(conn.key) - return err -} +// func (route *UDPRoute) Handle(c types.StreamConn) error { +// switch c := c.(type) { +// case *UDPConn: +// err := c.Start() +// route.connMap.Delete(c.key) +// c.Close() +// return err +// case *net.TCPConn: +// in := route.listeningConn +// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr)) +// if err != nil { +// return err +// } +// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start() +// c.Close() +// return err +// } +// return fmt.Errorf("unknown conn type: %T", c) +// } -func (route *UDPRoute) CloseListeners() { - if route.listeningConn != nil { - route.listeningConn.Close() - } - route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) { - if err := conn.Close(); err != nil { - route.l.Errorf("error closing conn: %s", err) - } - }) - route.connMap.Clear() -} +// func (route *UDPRoute) Close() error { +// route.connMap.RangeAllParallel(func(k string, v *UDPConn) { +// v.Close() +// }) +// route.connMap.Clear() +// return route.listeningConn.Close() +// } -// Close implements types.StreamConn -func (conn *UDPConn) Close() error { - return errors.Join(conn.src.Close(), conn.dst.Close()) -} +// // Close implements types.StreamConn +// func (conn *UDPConn) Close() error { +// return errors.Join(conn.src.Close(), conn.dst.Close()) +// } -// RemoteAddr implements types.StreamConn -func (conn *UDPConn) RemoteAddr() net.Addr { - return conn.src.RemoteAddr() -} +// // RemoteAddr implements types.StreamConn +// func (conn *UDPConn) RemoteAddr() net.Addr { +// return conn.src.RemoteAddr() +// } -type sourceRWCloser struct { - server *net.UDPConn - *net.UDPConn -} +// type sourceRWCloser struct { +// server net.PacketConn +// net.Conn +// } -func (w sourceRWCloser) Write(p []byte) (int, error) { - return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp -} +// func (w sourceRWCloser) Write(p []byte) (int, error) { +// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr)) +// } diff --git a/internal/route/udp_listener.go b/internal/route/udp_listener.go new file mode 100644 index 0000000..04f3b65 --- /dev/null +++ b/internal/route/udp_listener.go @@ -0,0 +1,73 @@ +package route + +import ( + "context" + "io" + "net" + "sync" + + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +type ( + UDPListener struct { + ctx context.Context + listener net.PacketConn + connMap UDPConnMap + mu sync.Mutex + } + UDPConnMap = F.Map[string, net.Conn] +) + +var NewUDPConnMap = F.NewMap[UDPConnMap] + +func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener { + return &UDPListener{ + ctx: ctx, + listener: listener, + connMap: NewUDPConnMap(), + } +} + +// Addr implements net.Listener. +func (route *UDPListener) Addr() net.Addr { + return route.listener.LocalAddr() +} + +func (udpl *UDPListener) Accept() (net.Conn, error) { + in := udpl.listener + + buffer := make([]byte, streamBufferSize) + nRead, srcAddr, err := in.ReadFrom(buffer) + if err != nil { + return nil, err + } + + if nRead == 0 { + return nil, io.ErrShortBuffer + } + + udpl.mu.Lock() + defer udpl.mu.Unlock() + + key := srcAddr.String() + conn, ok := udpl.connMap.Load(key) + if !ok { + dialer := &net.Dialer{Timeout: streamDialTimeout} + srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String()) + if err != nil { + return nil, err + } + udpl.connMap.Store(key, srcConn) + } + return conn, nil +} + +// Close implements net.Listener. +func (route *UDPListener) Close() error { + route.connMap.RangeAllParallel(func(key string, conn net.Conn) { + conn.Close() + }) + route.connMap.Clear() + return route.listener.Close() +} diff --git a/internal/server/server.go b/internal/server/server.go index b6784e9..0d917ec 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -116,7 +116,7 @@ func (s *Server) Start() { }() } - s.task.OnComplete("stop server", s.stop) + s.task.OnFinished("stop server", s.stop) } func (s *Server) stop() { diff --git a/internal/task/task.go b/internal/task/task.go index eb37106..c245f84 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -39,10 +39,11 @@ type ( // Use Task.Finish to stop all subtasks of the task. Task interface { TaskFinisher + fmt.Stringer // Name returns the name of the task. Name() string // Context returns the context associated with the task. This context is - // canceled when Finish is called. + // canceled when Finish of the task is called, or parent task is canceled. Context() context.Context // FinishCause returns the reason / error that caused the task to be finished. FinishCause() error @@ -53,12 +54,16 @@ type ( // If the parent's context is already canceled, the returned subtask will be canceled immediately. // // This should not be called after Finish, Wait, or WaitSubTasks is called. - Subtask(usageFmt string, args ...any) Task - // OnComplete calls fn when the task and all subtasks are finished. + Subtask(name string) Task + // OnFinished calls fn when all subtasks are finished. // // It cannot be called after Finish or Wait is called. - OnComplete(about string, fn func()) - // Wait waits for all subtasks, itself and all OnComplete to finish. + OnFinished(about string, fn func()) + // OnCancel calls fn when the task is canceled. + // + // It cannot be called after Finish or Wait is called. + OnCancel(about string, fn func()) + // Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish. // // It must be called only after Finish is called. Wait() @@ -76,37 +81,46 @@ type ( // The task passed must be a subtask of the caller task. // // callerSubtask.Finish must be called when start fails or the object is finished. - Start(callerSubtask Task) E.NestedError + Start(callerSubtask Task) E.Error } TaskFinisher interface { - // Finish marks the task as finished by cancelling its context. + // Finish marks the task as finished and cancel its context. // - // Then call Wait to wait for all subtasks and OnComplete of the task to finish. + // Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished + // of the task to finish. // // Note that it will also cancel all subtasks. - Finish(reason string) + Finish(reason any) } task struct { ctx context.Context cancel context.CancelCauseFunc - parent *task - subtasks *xsync.MapOf[*task, struct{}] + parent *task + subtasks *xsync.MapOf[*task, struct{}] + subTasksWg sync.WaitGroup name, line string - subTasksWg, onCompleteWg sync.WaitGroup + OnFinishedFuncs []func() + OnFinishedMu sync.Mutex + onFinishedWg sync.WaitGroup + + finishOnce sync.Once } ) var ( ErrProgramExiting = errors.New("program exiting") - ErrTaskCancelled = errors.New("task cancelled") + ErrTaskCanceled = errors.New("task canceled") ) // GlobalTask returns a new Task with the given name, derived from the global context. func GlobalTask(format string, args ...any) Task { - return globalTask.Subtask(format, args...) + if len(args) > 0 { + format = fmt.Sprintf(format, args...) + } + return globalTask.Subtask(format) } // DebugTaskMap returns a map[string]any representation of the global task tree. @@ -155,6 +169,10 @@ func (t *task) Name() string { return t.name } +func (t *task) String() string { + return t.name +} + func (t *task) Context() context.Context { return t.ctx } @@ -171,43 +189,83 @@ func (t *task) Parent() Task { return t.parent } -func (t *task) OnComplete(about string, fn func()) { - t.onCompleteWg.Add(1) +func (t *task) runAllOnFinished(onCompTask Task) { + <-t.ctx.Done() + t.WaitSubTasks() + for _, OnFinishedFunc := range t.OnFinishedFuncs { + OnFinishedFunc() + t.onFinishedWg.Done() + } + onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done")) +} + +func (t *task) OnFinished(about string, fn func()) { + if t.parent == globalTask { + t.OnCancel(about, fn) + return + } + t.onFinishedWg.Add(1) + t.OnFinishedMu.Lock() + defer t.OnFinishedMu.Unlock() + + if t.OnFinishedFuncs == nil { + onCompTask := GlobalTask(t.name + " > OnFinished") + go t.runAllOnFinished(onCompTask) + } var file string var line int if common.IsTrace { _, file, line, _ = runtime.Caller(1) } - go func() { + idx := len(t.OnFinishedFuncs) + wrapped := func() { defer func() { if err := recover(); err != nil { - logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err) + logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err) } }() - defer t.onCompleteWg.Done() - t.subTasksWg.Wait() + fn() + logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about) + } + t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped) +} + +func (t *task) OnCancel(about string, fn func()) { + onCompTask := GlobalTask(t.name + " > OnFinished") + go func() { <-t.ctx.Done() fn() - logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about) - t.cancel(nil) // ensure resources are released + onCompTask.Finish("done") + logrus.Tracef("%s > onCancel done: %s", t.name, about) }() } -func (t *task) Finish(reason string) { - t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason)) - t.Wait() +func (t *task) Finish(reason any) { + var format string + switch reason.(type) { + case error: + format = "%w" + case string, fmt.Stringer: + format = "%s" + default: + format = "%v" + } + t.finishOnce.Do(func() { + t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason)) + t.Wait() + }) } -func (t *task) Subtask(format string, args ...any) Task { - if len(args) > 0 { - format = fmt.Sprintf(format, args...) - } +func (t *task) Subtask(name string) Task { ctx, cancel := context.WithCancelCause(t.ctx) - return t.newSubTask(ctx, cancel, format) + return t.newSubTask(ctx, cancel, name) } func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task { parent := t + if common.IsTrace { + name = parent.name + " > " + name + } subtask := &task{ ctx: ctx, cancel: cancel, @@ -222,10 +280,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n if ok { subtask.line = fmt.Sprintf("%s:%d", file, line) } - logrus.Tracef("line %s\ntask %q started", subtask.line, name) + logrus.Tracef("line %s\n%s started", subtask.line, name) go func() { subtask.Wait() - logrus.Tracef("task %q finished", subtask.Name()) + logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause()) }() } go func() { @@ -237,11 +295,9 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n } func (t *task) Wait() { - t.subTasksWg.Wait() - if t != globalTask { - <-t.ctx.Done() - } - t.onCompleteWg.Wait() + <-t.ctx.Done() + t.WaitSubTasks() + t.onFinishedWg.Wait() } func (t *task) WaitSubTasks() { @@ -270,9 +326,9 @@ func (t *task) tree(prefix ...string) string { } if t.line != "" { sb.WriteString("line " + t.line + "\n") - } - if len(pre) > 0 { - sb.WriteString(pre + "- ") + if len(pre) > 0 { + sb.WriteString(pre + "- ") + } } sb.WriteString(t.Name() + "\n") t.subtasks.Range(func(subtask *task, _ struct{}) bool { @@ -299,7 +355,8 @@ func (t *task) tree(prefix ...string) string { // only. func (t *task) serialize() map[string]any { m := make(map[string]any) - m["name"] = t.name + parts := strings.Split(t.name, ">") + m["name"] = strings.TrimSpace(parts[len(parts)-1]) if t.line != "" { m["line"] = t.line } diff --git a/internal/task/task_test.go b/internal/task/task_test.go index 1fc08e0..47f2143 100644 --- a/internal/task/task_test.go +++ b/internal/task/task_test.go @@ -40,7 +40,7 @@ func TestTaskCancellation(t *testing.T) { err := subTask.Context().Err() ExpectError(t, context.Canceled, err) cause := context.Cause(subTask.Context()) - ExpectError(t, ErrTaskCancelled, cause) + ExpectError(t, ErrTaskCanceled, cause) case <-time.After(1 * time.Second): t.Fatal("subTask context was not canceled as expected") } @@ -74,7 +74,7 @@ func TestOnComplete(t *testing.T) { task := GlobalTask("test") var value atomic.Int32 - task.OnComplete("set value", func() { + task.OnFinished("set value", func() { value.Store(1234) }) task.Finish("done") @@ -90,10 +90,10 @@ func TestGlobalContextWait(t *testing.T) { subTask1 := rootTask.Subtask("subtask1") subTask2 := rootTask.Subtask("subtask2") - subTask1.OnComplete("set finished", func() { + subTask1.OnFinished("set finished", func() { finished1 = true }) - subTask2.OnComplete("set finished", func() { + subTask2.OnFinished("set finished", func() { finished2 = true }) @@ -117,8 +117,8 @@ func TestGlobalContextWait(t *testing.T) { ExpectTrue(t, finished1) ExpectTrue(t, finished2) ExpectError(t, context.Canceled, rootTask.Context().Err()) - ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context())) - ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context())) + ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context())) + ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context())) } func TestTimeoutOnGlobalContextWait(t *testing.T) { diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index 7cd71e9..657f405 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -160,7 +160,7 @@ func (m Map[KT, VT]) Has(k KT) bool { // Returns: // // error: if the unmarshaling fails -func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { +func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error { if m.Size() != 0 { return E.FailedWhy("unmarshal from yaml", "map is not empty") } diff --git a/internal/utils/io.go b/internal/utils/io.go index 18a8773..ea79772 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -152,7 +152,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) } -func LoadJSON[T any](path string, pointer *T) E.NestedError { +func LoadJSON[T any](path string, pointer *T) E.Error { data, err := E.Check(os.ReadFile(path)) if err.HasError() { return err @@ -160,7 +160,7 @@ func LoadJSON[T any](path string, pointer *T) E.NestedError { return E.From(json.Unmarshal(data, pointer)) } -func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError { +func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error { data, err := E.Check(json.Marshal(pointer)) if err.HasError() { return err diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index db174f9..7565a5e 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -19,11 +19,11 @@ import ( type ( SerializedObject = map[string]any Converter interface { - ConvertFrom(value any) E.NestedError + ConvertFrom(value any) E.Error } ) -func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { +func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { var i any err := yaml.Unmarshal(data, &i) @@ -66,7 +66,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { // Returns: // - result: The resulting map[string]any representation of the data. // - error: An error if the data type is unsupported or if there is an error during conversion. -func Serialize(data any) (SerializedObject, E.NestedError) { +func Serialize(data any) (SerializedObject, E.Error) { result := make(map[string]any) // Use reflection to inspect the data type @@ -137,7 +137,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { // If the target value is a map[string]any, the SerializedObject will be deserialized into the map. // // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. -func Deserialize(src SerializedObject, dst any) E.NestedError { +func Deserialize(src SerializedObject, dst any) E.Error { if src == nil { return E.Invalid("src", "nil") } @@ -210,7 +210,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { // // Returns: // - error: the error occurred during conversion, or nil if no error occurred. -func Convert(src reflect.Value, dst reflect.Value) E.NestedError { +func Convert(src reflect.Value, dst reflect.Value) E.Error { srcT := src.Type() dstT := dst.Type() @@ -277,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError { return converter.ConvertFrom(src.Interface()) } -func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) { +func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) { convertible = true if dst.Kind() == reflect.Ptr { if dst.IsNil() { @@ -379,7 +379,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N return true, Convert(reflect.ValueOf(tmp), dst) } -func DeserializeJSON(j map[string]string, target any) E.NestedError { +func DeserializeJSON(j map[string]string, target any) E.Error { data, err := E.Check(json.Marshal(j)) if err != nil { return err diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index a87499d..632d852 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -113,7 +113,7 @@ type testType struct { bar string } -func (c *testType) ConvertFrom(v any) E.NestedError { +func (c *testType) ConvertFrom(v any) E.Error { switch v := v.(type) { case string: c.bar = v diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index d74b518..8331d0c 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -98,7 +98,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) { return got.(T) } -func Must[T any](v T, err E.NestedError) T { +func Must[T any](v T, err E.Error) T { if err != nil { panic(err) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index 46d833b..d729201 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -21,7 +21,7 @@ type DirWatcher struct { mu sync.Mutex eventCh chan Event - errCh chan E.NestedError + errCh chan E.Error ctx context.Context } @@ -48,14 +48,14 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher { w: w, fwMap: F.NewMapOf[string, *fileWatcher](), eventCh: make(chan Event), - errCh: make(chan E.NestedError), + errCh: make(chan E.Error), ctx: ctx, } go helper.start() return helper } -func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) { +func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.Error) { return h.eventCh, h.errCh } @@ -71,7 +71,7 @@ func (h *DirWatcher) Add(relPath string) Watcher { s = &fileWatcher{ relPath: relPath, eventCh: make(chan Event), - errCh: make(chan E.NestedError), + errCh: make(chan E.Error), } go func() { defer func() { diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 6446860..59bbe6b 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -36,6 +36,14 @@ var ( NewDockerFilter = filters.NewArgs + optionsDefault = DockerListOptions{Filters: NewDockerFilter( + DockerFilterContainer, + DockerFilterStart, + // DockerFilterStop, + DockerFilterDie, + DockerFilterDestroy, + )} + dockerWatcherRetryInterval = 3 * time.Second ) @@ -61,13 +69,13 @@ func NewDockerWatcherWithClient(client D.Client) DockerWatcher { WithField("host", client.DaemonHost()))} } -func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { - return w.EventsWithOptions(ctx, optionsWatchAll) +func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) { + return w.EventsWithOptions(ctx, optionsDefault) } -func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) { +func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.Error) { eventCh := make(chan Event) - errCh := make(chan E.NestedError) + errCh := make(chan E.Error) go func() { defer close(eventCh) @@ -80,7 +88,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList }() if !w.client.Connected() { - var err E.NestedError + var err E.Error attempts := 0 for { w.client, err = D.ConnectClient(w.host) @@ -141,11 +149,3 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList return eventCh, errCh } - -var optionsWatchAll = DockerListOptions{Filters: NewDockerFilter( - DockerFilterContainer, - DockerFilterStart, - // DockerFilterStop, - DockerFilterDie, - DockerFilterDestroy, -)} diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go index b8891ba..992166e 100644 --- a/internal/watcher/events/event_queue.go +++ b/internal/watcher/events/event_queue.go @@ -3,20 +3,22 @@ package events import ( "time" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/task" ) type ( EventQueue struct { - task task.Task - queue []Event - ticker *time.Ticker - onFlush OnFlushFunc - onError OnErrorFunc + task task.Task + queue []Event + ticker *time.Ticker + flushInterval time.Duration + onFlush OnFlushFunc + onError OnErrorFunc } OnFlushFunc = func(flushTask task.Task, events []Event) - OnErrorFunc = func(err E.NestedError) + OnErrorFunc = func(err E.Error) ) const eventQueueCapacity = 10 @@ -35,40 +37,45 @@ const eventQueueCapacity = 10 // flushTask.Finish must be called after the flush is done, // but the onFlush function can return earlier (e.g. run in another goroutine). // -// If task is cancelled before the flushInterval is reached, the events in queue will be discarded. +// If task is canceled before the flushInterval is reached, the events in queue will be discarded. func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { return &EventQueue{ - task: parent.Subtask("event queue"), - queue: make([]Event, 0, eventQueueCapacity), - ticker: time.NewTicker(flushInterval), - onFlush: onFlush, - onError: onError, + task: parent.Subtask("event queue"), + queue: make([]Event, 0, eventQueueCapacity), + ticker: time.NewTicker(flushInterval), + flushInterval: flushInterval, + onFlush: onFlush, + onError: onError, } } -func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) { +func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { go func() { defer e.ticker.Stop() for { select { case <-e.task.Context().Done(): - e.task.Finish(e.task.FinishCause().Error()) return case <-e.ticker.C: if len(e.queue) > 0 { flushTask := e.task.Subtask("flush events") queue := e.queue e.queue = make([]Event, 0, eventQueueCapacity) - go func() { - defer func() { - if err := recover(); err != nil { - e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name())) - } + if !common.IsDebug { + go func() { + defer func() { + if err := recover(); err != nil { + e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name())) + } + }() + e.onFlush(flushTask, queue) }() - e.onFlush(flushTask, queue) - }() + } else { + go e.onFlush(flushTask, queue) + } flushTask.Wait() } + e.ticker.Reset(e.flushInterval) case event, ok := <-eventCh: e.queue = append(e.queue, event) if !ok { diff --git a/internal/watcher/file_watcher.go b/internal/watcher/file_watcher.go index 4bd310b..a53025d 100644 --- a/internal/watcher/file_watcher.go +++ b/internal/watcher/file_watcher.go @@ -9,9 +9,9 @@ import ( type fileWatcher struct { relPath string eventCh chan Event - errCh chan E.NestedError + errCh chan E.Error } -func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { +func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) { return fw.eventCh, fw.errCh } diff --git a/internal/watcher/health/health_checker.go b/internal/watcher/health/health_checker.go new file mode 100644 index 0000000..42ac088 --- /dev/null +++ b/internal/watcher/health/health_checker.go @@ -0,0 +1,28 @@ +package health + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/task" +) + +type ( + HealthMonitor interface { + task.TaskStarter + task.TaskFinisher + fmt.Stringer + json.Marshaler + Status() Status + Uptime() time.Duration + Name() string + } + HealthChecker interface { + CheckHealth() (healthy bool, detail string, err error) + URL() types.URL + Config() *HealthCheckConfig + UpdateURL(url types.URL) + } +) diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index 05405d1..41773ec 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -2,9 +2,7 @@ package health import ( "context" - "encoding/json" "errors" - "fmt" "time" E "github.com/yusing/go-proxy/internal/error" @@ -15,21 +13,6 @@ import ( ) type ( - HealthMonitor interface { - task.TaskStarter - task.TaskFinisher - fmt.Stringer - json.Marshaler - Status() Status - Uptime() time.Duration - Name() string - } - HealthChecker interface { - CheckHealth() (healthy bool, detail string, err error) - URL() types.URL - Config() *HealthCheckConfig - UpdateURL(url types.URL) - } HealthCheckFunc func() (healthy bool, detail string, err error) monitor struct { service string @@ -71,7 +54,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance } // Start implements task.TaskStarter. -func (mon *monitor) Start(routeSubtask task.Task) E.NestedError { +func (mon *monitor) Start(routeSubtask task.Task) E.Error { mon.service = routeSubtask.Parent().Name() mon.task = routeSubtask @@ -84,7 +67,6 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError { if mon.status.Load() != StatusError { mon.status.Store(StatusUnknown) } - mon.task.Finish(mon.task.FinishCause().Error()) }() if err := mon.checkUpdateHealth(); err != nil { @@ -115,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError { } // Finish implements task.TaskFinisher. -func (mon *monitor) Finish(reason string) { +func (mon *monitor) Finish(reason any) { mon.task.Finish(reason) } @@ -169,10 +151,10 @@ func (mon *monitor) MarshalJSON() ([]byte, error) { }).MarshalJSON() } -func (mon *monitor) checkUpdateHealth() E.NestedError { +func (mon *monitor) checkUpdateHealth() E.Error { healthy, detail, err := mon.checkHealth() if err != nil { - defer mon.task.Finish(err.Error()) + defer mon.task.Finish(err) mon.status.Store(StatusError) if !errors.Is(err, context.Canceled) { return E.Failure("check health").With(err) diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index fa39951..e33cd55 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -10,5 +10,5 @@ import ( type Event = events.Event type Watcher interface { - Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) + Events(ctx context.Context) (<-chan Event, <-chan E.Error) }