fixed loadbalancer with idlewatcher, fixed reload issue

This commit is contained in:
yusing 2024-10-20 09:46:02 +08:00
parent 01ffe0d97c
commit a278711421
78 changed files with 906 additions and 609 deletions

1
.gitignore vendored
View file

@ -24,3 +24,4 @@ todo.md
.aider* .aider*
mtrace.json mtrace.json
.env .env
test.Dockerfile

View file

@ -55,7 +55,7 @@ repush:
git push gitlab dev --force git push gitlab dev --force
rapid-crash: rapid-crash:
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\ sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\
sleep 3 &&\ sleep 3 &&\
sudo docker rm -f test_crash sudo docker rm -f test_crash
@ -64,4 +64,4 @@ debug-list-containers:
ci-test: ci-test:
mkdir -p /tmp/artifacts mkdir -p /tmp/artifacts
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"

View file

@ -84,7 +84,7 @@ func main() {
middleware.LoadComposeFiles() middleware.LoadComposeFiles()
var cfg *config.Config var cfg *config.Config
var err E.NestedError var err E.Error
if cfg, err = config.Load(); err != nil { if cfg, err = config.Load(); err != nil {
logrus.Warn(err) logrus.Warn(err)
} }

View file

@ -39,7 +39,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return return
} }
var validateErr E.NestedError var validateErr E.Error
if filename == common.ConfigFileName { if filename == common.ConfigFileName {
validateErr = config.Validate(content) validateErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {

View file

@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
) )
func ReloadServer() E.NestedError { func ReloadServer() E.Error {
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil) resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
if err != nil { if err != nil {
return E.From(err) return E.From(err)
@ -34,7 +34,7 @@ func ReloadServer() E.NestedError {
return nil return nil
} }
func List[T any](what string) (_ T, outErr E.NestedError) { func List[T any](what string) (_ T, outErr E.Error) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what)) resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
if err != nil { if err != nil {
outErr = E.From(err) outErr = E.From(err)
@ -54,14 +54,14 @@ func List[T any](what string) (_ T, outErr E.NestedError) {
return res, nil return res, nil
} }
func ListRoutes() (map[string]map[string]any, E.NestedError) { func ListRoutes() (map[string]map[string]any, E.Error) {
return List[map[string]map[string]any](v1.ListRoutes) return List[map[string]map[string]any](v1.ListRoutes)
} }
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) { func ListMiddlewareTraces() (middleware.Traces, E.Error) {
return List[middleware.Traces](v1.ListMiddlewareTraces) return List[middleware.Traces](v1.ListMiddlewareTraces)
} }
func DebugListTasks() (map[string]any, E.NestedError) { func DebugListTasks() (map[string]any, E.Error) {
return List[map[string]any](v1.ListTasks) return List[map[string]any](v1.ListTasks)
} }

View file

@ -27,7 +27,7 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
return (*Config)(cfg) return (*Config)(cfg)
} }
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) { func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
b := E.NewBuilder("unable to initialize autocert") b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res) defer b.To(&res)

View file

@ -29,7 +29,7 @@ type (
tlsCert *tls.Certificate tlsCert *tls.Certificate
certExpiries CertExpiries certExpiries CertExpiries
} }
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError) ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
CertExpiries map[string]time.Time CertExpiries map[string]time.Time
) )
@ -57,7 +57,7 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries return p.certExpiries
} }
func (p *Provider) ObtainCert() (res E.NestedError) { func (p *Provider) ObtainCert() (res E.Error) {
b := E.NewBuilder("failed to obtain certificate") b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res) defer b.To(&res)
@ -112,7 +112,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
return nil return nil
} }
func (p *Provider) LoadCert() E.NestedError { func (p *Provider) LoadCert() E.Error {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.HasError() { if err.HasError() {
return err return err
@ -158,7 +158,7 @@ func (p *Provider) ScheduleRenewal() {
}() }()
} }
func (p *Provider) initClient() E.NestedError { func (p *Provider) initClient() E.Error {
legoClient, err := E.Check(lego.NewClient(p.legoCfg)) legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() { if err.HasError() {
return E.FailWith("create lego client", err) return E.FailWith("create lego client", err)
@ -178,7 +178,7 @@ func (p *Provider) initClient() E.NestedError {
return nil return nil
} }
func (p *Provider) registerACME() E.NestedError { func (p *Provider) registerACME() E.Error {
if p.user.Registration != nil { if p.user.Registration != nil {
return nil return nil
} }
@ -191,7 +191,7 @@ func (p *Provider) registerACME() E.NestedError {
return nil return nil
} }
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { func (p *Provider) saveCert(cert *certificate.Resource) E.Error {
/* This should have been done in setup /* This should have been done in setup
but double check is always a good choice.*/ but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath)) _, err := os.Stat(path.Dir(p.cfg.CertPath))
@ -239,7 +239,7 @@ func (p *Provider) certState() CertState {
return CertStateValid return CertStateValid
} }
func (p *Provider) renewIfNeeded() E.NestedError { func (p *Provider) renewIfNeeded() E.Error {
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
return nil return nil
} }
@ -259,7 +259,7 @@ func (p *Provider) renewIfNeeded() E.NestedError {
return nil return nil
} }
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) {
r := make(CertExpiries, len(cert.Certificate)) r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate { for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert)) x509Cert, err := E.Check(x509.ParseCertificate(cert))
@ -281,7 +281,7 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT, defaultCfg func() *CT,
newProvider func(*CT) (PT, error), newProvider func(*CT) (PT, error),
) ProviderGenerator { ) ProviderGenerator {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) { return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg() cfg := defaultCfg()
err := U.Deserialize(opt, cfg) err := U.Deserialize(opt, cfg)
if err.HasError() { if err.HasError() {

View file

@ -6,7 +6,7 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
func (p *Provider) Setup() (err E.NestedError) { func (p *Provider) Setup() (err E.Error) {
if err = p.LoadCert(); err != nil { if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err return err

View file

@ -47,3 +47,5 @@ const (
StopTimeoutDefault = "10s" StopTimeoutDefault = "10s"
StopMethodDefault = "stop" StopMethodDefault = "stop"
) )
const HeaderCheckRedirect = "X-Goproxy-Check-Redirect"

View file

@ -55,7 +55,7 @@ func newConfig() *Config {
} }
} }
func Load() (*Config, E.NestedError) { func Load() (*Config, E.Error) {
if instance != nil { if instance != nil {
return instance, nil return instance, nil
} }
@ -64,7 +64,7 @@ func Load() (*Config, E.NestedError) {
return instance, instance.load() return instance, instance.load()
} }
func Validate(data []byte) E.NestedError { func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data) return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
} }
@ -78,7 +78,7 @@ func WatchChanges() {
task, task,
configEventFlushInterval, configEventFlushInterval,
OnConfigChange, OnConfigChange,
func(err E.NestedError) { func(err E.Error) {
logger.Error(err) logger.Error(err)
}, },
) )
@ -104,7 +104,7 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
} }
} }
func Reload() E.NestedError { func Reload() E.Error {
// avoid race between config change and API reload request // avoid race between config change and API reload request
reloadMu.Lock() reloadMu.Lock()
defer reloadMu.Unlock() defer reloadMu.Unlock()
@ -139,7 +139,7 @@ func (cfg *Config) Task() task.Task {
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
b := E.NewBuilder("errors starting providers") b := E.NewBuilder("errors starting providers")
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName()))) b.Add(p.Start(cfg.task.Subtask(p.String())))
}) })
if b.HasError() { if b.HasError() {
@ -147,7 +147,7 @@ func (cfg *Config) StartProxyProviders() {
} }
} }
func (cfg *Config) load() (res E.NestedError) { func (cfg *Config) load() (res E.Error) {
b := E.NewBuilder("errors loading config") b := E.NewBuilder("errors loading config")
defer b.To(&res) defer b.To(&res)
@ -182,7 +182,7 @@ func (cfg *Config) load() (res E.NestedError) {
return return
} }
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) { func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
if cfg.autocertProvider != nil { if cfg.autocertProvider != nil {
return return
} }
@ -197,7 +197,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return return
} }
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) { func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.Error) {
subtask := cfg.task.Subtask("load providers") subtask := cfg.task.Subtask("load providers")
defer subtask.Finish("done") defer subtask.Finish("done")

View file

@ -37,7 +37,7 @@ var (
) )
func init() { func init() {
task.GlobalTask("close docker clients").OnComplete("", func() { task.GlobalTask("close docker clients").OnFinished("", func() {
clientMap.RangeAllParallel(func(_ string, c Client) { clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() { if c.Connected() {
c.Client.Close() c.Client.Close()
@ -70,7 +70,7 @@ func (c *SharedClient) Close() error {
// Returns: // Returns:
// - Client: the Docker client connection. // - Client: the Docker client connection.
// - error: an error if the connection failed. // - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.NestedError) { func ConnectClient(host string) (Client, E.Error) {
clientMapMu.Lock() clientMapMu.Lock()
defer clientMapMu.Unlock() defer clientMapMu.Unlock()

View file

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"text/template" "text/template"
"github.com/yusing/go-proxy/internal/common"
) )
type templateData struct { type templateData struct {
@ -18,17 +20,15 @@ type templateData struct {
var loadingPage []byte var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage))) var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
func (w *Watcher) makeLoadingPageBody() []byte { func (w *Watcher) makeLoadingPageBody() []byte {
msg := fmt.Sprintf("%s is starting...", w.ContainerName) msg := fmt.Sprintf("%s is starting...", w.ContainerName)
data := new(templateData) data := new(templateData)
data.CheckRedirectHeader = headerCheckRedirect data.CheckRedirectHeader = common.HeaderCheckRedirect
data.Title = w.ContainerName data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, " ", " ") data.Message = strings.ReplaceAll(msg, " ", " ")
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect))) buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect)))
err := loadingPageTmpl.Execute(buf, data) err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen in production if err != nil { // should never happen in production
panic(err) panic(err)

View file

@ -1,4 +1,4 @@
package idlewatcher package types
import ( import (
"time" "time"
@ -30,7 +30,7 @@ const (
StopMethodKill StopMethod = "kill" StopMethodKill StopMethod = "kill"
) )
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) { func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
if cont == nil { if cont == nil {
return nil, nil return nil, nil
} }
@ -80,7 +80,7 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
}, nil }, nil
} }
func validateDurationPostitive(value string) (time.Duration, E.NestedError) { func validateDurationPostitive(value string) (time.Duration, E.Error) {
d, err := time.ParseDuration(value) d, err := time.ParseDuration(value)
if err != nil { if err != nil {
return 0, E.Invalid("duration", value).With(err) return 0, E.Invalid("duration", value).With(err)
@ -91,7 +91,7 @@ func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
return d, nil return d, nil
} }
func validateSignal(s string) (Signal, E.NestedError) { func validateSignal(s string) (Signal, E.Error) {
switch s { switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT": "INT", "TERM", "HUP", "QUIT":
@ -101,7 +101,7 @@ func validateSignal(s string) (Signal, E.NestedError) {
return "", E.Invalid("signal", s) return "", E.Invalid("signal", s)
} }
func validateStopMethod(s string) (StopMethod, E.NestedError) { func validateStopMethod(s string) (StopMethod, E.Error) {
sm := StopMethod(s) sm := StopMethod(s)
switch sm { switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill: case StopMethodPause, StopMethodStop, StopMethodKill:

View file

@ -0,0 +1,14 @@
package types
import (
"net/http"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}

View file

@ -1,10 +1,10 @@
package idlewatcher package idlewatcher
import ( import (
"net/http"
"sync/atomic" "sync/atomic"
"time" "time"
. "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
@ -14,12 +14,6 @@ import (
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}
type waker struct { type waker struct {
_ U.NoCopy _ U.NoCopy
@ -37,7 +31,7 @@ const (
// TODO: support stream // TODO: support stream
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) { func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.HealthCheckConfig() hcCfg := entry.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout hcCfg.Timeout = idleWakerCheckTimeout
@ -62,24 +56,26 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
} }
// lifetime should follow route provider // lifetime should follow route provider
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) { func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil) return newWaker(providerSubTask, entry, rp, nil)
} }
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) { func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(providerSubTask, entry, nil, stream) return newWaker(providerSubTask, entry, nil, stream)
} }
// Start implements health.HealthMonitor. // Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError { func (w *Watcher) Start(routeSubTask task.Task) E.Error {
w.task.OnComplete("stop route", func() { routeSubTask.Finish("ignored")
routeSubTask.Parent().Finish("watcher stopped") w.task.OnCancel("stop route", func() {
routeSubTask.Parent().Finish(w.task.FinishCause())
}) })
return nil return nil
} }
// Finish implements health.HealthMonitor. // Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason string) {} func (w *Watcher) Finish(reason any) {
}
// Name implements health.HealthMonitor. // Name implements health.HealthMonitor.
func (w *Watcher) Name() string { func (w *Watcher) Name() string {
@ -109,6 +105,7 @@ func (w *Watcher) Status() health.Status {
healthy, _, err := w.hc.CheckHealth() healthy, _, err := w.hc.CheckHealth()
switch { switch {
case err != nil: case err != nil:
w.ready.Store(false)
return health.StatusError return health.StatusError
case healthy: case healthy:
w.ready.Store(true) w.ready.Store(true)

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
@ -37,7 +38,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
accept := gphttp.GetAccept(r.Header) accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != "" isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != ""
if !isCheckRedirect && acceptHTML { if !isCheckRedirect && acceptHTML {
// Send a loading response to the client // Send a loading response to the client
body := w.makeLoadingPageBody() body := w.makeLoadingPageBody()
@ -56,14 +57,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel() defer cancel()
checkCancelled := func() bool { checkCanceled := func() bool {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context())) w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true return true
case <-ctx.Done(): case <-ctx.Done():
w.l.Debugf("wake cancelled: %s", context.Cause(ctx)) w.l.Debugf("wake canceled: %s", context.Cause(ctx))
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return true return true
default: default:
@ -71,7 +72,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
} }
} }
if checkCancelled() { if checkCanceled() {
return false return false
} }
@ -84,14 +85,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
} }
for { for {
if checkCancelled() { if checkCanceled() {
return false return false
} }
if w.Status() == health.StatusHealthy { if w.Status() == health.StatusHealthy {
w.resetIdleTimer() w.resetIdleTimer()
if isCheckRedirect { if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting...", w.String()) logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL())
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
return return
} }

View file

@ -8,44 +8,47 @@ import (
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
// Setup implements types.Stream. // Setup implements types.Stream.
func (w *Watcher) Addr() net.Addr {
return w.stream.Addr()
}
func (w *Watcher) Setup() error { func (w *Watcher) Setup() error {
return w.stream.Setup() return w.stream.Setup()
} }
// Accept implements types.Stream. // Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) { func (w *Watcher) Accept() (conn net.Conn, err error) {
conn, err = w.stream.Accept() conn, err = w.stream.Accept()
// timeout means no connection is accepted if err != nil {
var nErr *net.OpError logrus.Errorf("accept failed with error: %s", err)
ok := errors.As(err, &nErr)
if ok && nErr.Timeout() {
return return
} }
if err := w.wakeFromStream(); err != nil { if err := w.wakeFromStream(); err != nil {
return nil, err w.l.Error(err)
} }
return w.stream.Accept() return
}
// CloseListeners implements types.Stream.
func (w *Watcher) CloseListeners() {
w.stream.CloseListeners()
} }
// Handle implements types.Stream. // Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error { func (w *Watcher) Handle(conn net.Conn) error {
if err := w.wakeFromStream(); err != nil { if err := w.wakeFromStream(); err != nil {
return err return err
} }
return w.stream.Handle(conn) return w.stream.Handle(conn)
} }
// Close implements types.Stream.
func (w *Watcher) Close() error {
return w.stream.Close()
}
func (w *Watcher) wakeFromStream() error { func (w *Watcher) wakeFromStream() error {
w.resetIdleTimer()
// pass through if container is already ready // pass through if container is already ready
if w.ready.Load() { if w.ready.Load() {
return nil return nil
@ -66,11 +69,11 @@ func (w *Watcher) wakeFromStream() error {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
cause := w.task.FinishCause() cause := w.task.FinishCause()
w.l.Debugf("wake cancelled: %s", cause) w.l.Debugf("wake canceled: %s", cause)
return cause return cause
case <-ctx.Done(): case <-ctx.Done():
cause := context.Cause(ctx) cause := context.Cause(ctx)
w.l.Debugf("wake cancelled: %s", cause) w.l.Debugf("wake canceled: %s", cause)
return cause return cause
default: default:
} }

View file

@ -10,7 +10,7 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -49,7 +49,7 @@ var (
const dockerReqTimeout = 3 * time.Second const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) { func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) {
failure := E.Failure("idle_watcher register") failure := E.Failure("idle_watcher register")
cfg := entry.IdlewatcherConfig() cfg := entry.IdlewatcherConfig()
@ -66,6 +66,7 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
w.Config = cfg w.Config = cfg
w.waker = waker w.waker = waker
w.resetIdleTimer() w.resetIdleTimer()
providerSubtask.Finish("used existing watcher")
return w, nil return w, nil
} }
@ -88,13 +89,11 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
go func() { go func() {
cause := w.watchUntilDestroy() cause := w.watchUntilDestroy()
watcherMapMu.Lock()
watcherMap.Delete(w.ContainerID) watcherMap.Delete(w.ContainerID)
watcherMapMu.Unlock()
w.ticker.Stop() w.ticker.Stop()
w.client.Close() w.client.Close()
w.task.Finish(cause.Error()) w.task.Finish(cause)
}() }()
return w, nil return w, nil
@ -146,7 +145,7 @@ func (w *Watcher) wakeIfStopped() error {
return err return err
} }
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout)
defer cancel() defer cancel()
// !Hard coded here since theres no constants from Docker API // !Hard coded here since theres no constants from Docker API
@ -175,7 +174,7 @@ func (w *Watcher) getStopCallback() StopCallback {
panic("should not reach here") panic("should not reach here")
} }
return func() error { return func() error {
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout) ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second)
defer cancel() defer cancel()
return cb(ctx) return cb(ctx)
} }
@ -186,8 +185,8 @@ func (w *Watcher) resetIdleTimer() {
w.ticker.Reset(w.IdleTimeout) w.ticker.Reset(w.IdleTimeout)
} }
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) { func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("watcher for %s", w.ContainerID) eventTask = w.task.Subtask("docker event watcher")
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter( Filters: W.NewDockerFilter(
W.DockerFilterContainer, W.DockerFilterContainer,
@ -218,13 +217,12 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
func (w *Watcher) watchUntilDestroy() error { func (w *Watcher) watchUntilDestroy() error {
dockerWatcher := W.NewDockerWatcherWithClient(w.client) dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped")
for { for {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
cause := context.Cause(w.task.Context()) return w.task.FinishCause()
w.l.Debugf("watcher stopped by context done: %s", cause)
return cause
case err := <-dockerEventErrCh: case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) { if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err)) w.l.Error(E.FailWith("docker watcher", err))

View file

@ -8,7 +8,7 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) { func Inspect(dockerHost string, containerID string) (*Container, E.Error) {
client, err := ConnectClient(dockerHost) client, err := ConnectClient(dockerHost)
defer client.Close() defer client.Close()
@ -19,7 +19,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
return client.Inspect(containerID) return client.Inspect(containerID)
} }
func (c Client) Inspect(containerID string) (*Container, E.NestedError) { func (c Client) Inspect(containerID string) (*Container, E.Error) {
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel() defer cancel()

View file

@ -39,7 +39,7 @@ func (l *Label) String() string {
// //
// Returns: // Returns:
// - error: an error if the field does not exist. // - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError { func ApplyLabel[T any](obj *T, l *Label) E.Error {
if obj == nil { if obj == nil {
return E.Invalid("nil object", l) return E.Invalid("nil object", l)
} }
@ -81,7 +81,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
} }
} }
func ParseLabel(label string, value string) (*Label, E.NestedError) { func ParseLabel(label string, value string) (*Label, E.Error) {
parts := strings.Split(label, ".") parts := strings.Split(label, ".")
if len(parts) < 2 { if len(parts) < 2 {

View file

@ -11,11 +11,6 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
type ClientInfo struct {
Client Client
Containers []types.Container
}
var listOptions = container.ListOptions{ var listOptions = container.ListOptions{
// created|restarting|running|removing|paused|exited|dead // created|restarting|running|removing|paused|exited|dead
// Filters: filters.NewArgs( // Filters: filters.NewArgs(
@ -28,28 +23,21 @@ var listOptions = container.ListOptions{
All: true, All: true,
} }
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) { func ListContainers(clientHost string) ([]types.Container, E.Error) {
dockerClient, err := ConnectClient(clientHost) dockerClient, err := ConnectClient(clientHost)
if err.HasError() { if err.HasError() {
return nil, E.FailWith("connect to docker", err) return nil, E.FailWith("connect to docker", err)
} }
defer dockerClient.Close() defer dockerClient.Close()
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout")) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
defer cancel() defer cancel()
var containers []types.Container containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions))
if getContainer { if err.HasError() {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions)) return nil, E.FailWith("list containers", err)
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
} }
return containers, nil
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
} }
func IsErrConnectionFailed(err error) bool { func IsErrConnectionFailed(err error) bool {

View file

@ -12,7 +12,7 @@ type Builder struct {
type builder struct { type builder struct {
message string message string
errors []NestedError errors []Error
sync.Mutex sync.Mutex
} }
@ -28,7 +28,7 @@ func NewBuilder(format string, args ...any) Builder {
// adding nil is no-op, // adding nil is no-op,
// //
// flatten is a boolean flag to flatten the NestedError. // flatten is a boolean flag to flatten the NestedError.
func (b Builder) Add(err NestedError, flatten ...bool) { func (b Builder) Add(err Error, flatten ...bool) {
if err != nil { if err != nil {
b.Lock() b.Lock()
if len(flatten) > 0 && flatten[0] { if len(flatten) > 0 && flatten[0] {
@ -54,7 +54,7 @@ func (b Builder) Addf(format string, args ...any) {
} }
} }
func (b Builder) AddRange(errs ...NestedError) { func (b Builder) AddRange(errs ...Error) {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
for _, err := range errs { for _, err := range errs {
@ -77,14 +77,14 @@ func (b Builder) AddRangeE(errs ...error) {
// //
// Returns: // Returns:
// - NestedError: the built NestedError. // - NestedError: the built NestedError.
func (b Builder) Build() NestedError { func (b Builder) Build() Error {
if len(b.errors) == 0 { if len(b.errors) == 0 {
return nil return nil
} }
return Join(b.message, b.errors...) return Join(b.message, b.errors...)
} }
func (b Builder) To(ptr *NestedError) { func (b Builder) To(ptr *Error) {
switch { switch {
case ptr == nil: case ptr == nil:
return return

View file

@ -16,7 +16,7 @@ func TestBuilderEmpty(t *testing.T) {
func TestBuilderAddNil(t *testing.T) { func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf") eb := NewBuilder("asdf")
var err NestedError var err Error
for range 3 { for range 3 {
eb.Add(nil) eb.Add(nil)
} }
@ -53,7 +53,7 @@ func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred") eb := NewBuilder("error occurred")
eb.Addf("abcd") eb.Addf("abcd")
var err NestedError var err Error
eb.To(&err) eb.To(&err)
got := err.String() got := err.String()
expected := (`error occurred: expected := (`error occurred:

View file

@ -8,35 +8,35 @@ import (
) )
type ( type (
NestedError = *NestedErrorImpl Error = *ErrorImpl
NestedErrorImpl struct { ErrorImpl struct {
subject string subject string
err error err error
extras []NestedErrorImpl extras []ErrorImpl
} }
JSONNestedError struct { ErrorJSONMarshaller struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Err string `json:"error"` Err string `json:"error"`
Extras []JSONNestedError `json:"extras,omitempty"` Extras []ErrorJSONMarshaller `json:"extras,omitempty"`
} }
) )
func From(err error) NestedError { func From(err error) Error {
if IsNil(err) { if IsNil(err) {
return nil return nil
} }
return &NestedErrorImpl{err: err} return &ErrorImpl{err: err}
} }
func FromJSON(data []byte) (NestedError, bool) { func FromJSON(data []byte) (Error, bool) {
var j JSONNestedError var j ErrorJSONMarshaller
if err := json.Unmarshal(data, &j); err != nil { if err := json.Unmarshal(data, &j); err != nil {
return nil, false return nil, false
} }
if j.Err == "" { if j.Err == "" {
return nil, false return nil, false
} }
extras := make([]NestedErrorImpl, len(j.Extras)) extras := make([]ErrorImpl, len(j.Extras))
for i, e := range j.Extras { for i, e := range j.Extras {
extra, ok := fromJSONObject(e) extra, ok := fromJSONObject(e)
if !ok { if !ok {
@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) {
} }
extras[i] = *extra extras[i] = *extra
} }
return &NestedErrorImpl{ return &ErrorImpl{
subject: j.Subject, subject: j.Subject,
err: errors.New(j.Err), err: errors.New(j.Err),
extras: extras, extras: extras,
@ -53,12 +53,12 @@ func FromJSON(data []byte) (NestedError, bool) {
// Check is a helper function that // Check is a helper function that
// convert (T, error) to (T, NestedError). // convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, NestedError) { func Check[T any](obj T, err error) (T, Error) {
return obj, From(err) return obj, From(err)
} }
func Join(message string, err ...NestedError) NestedError { func Join(message string, err ...Error) Error {
extras := make([]NestedErrorImpl, len(err)) extras := make([]ErrorImpl, len(err))
nErr := 0 nErr := 0
for i, e := range err { for i, e := range err {
if e == nil { if e == nil {
@ -70,13 +70,13 @@ func Join(message string, err ...NestedError) NestedError {
if nErr == 0 { if nErr == 0 {
return nil return nil
} }
return &NestedErrorImpl{ return &ErrorImpl{
err: errors.New(message), err: errors.New(message),
extras: extras, extras: extras,
} }
} }
func JoinE(message string, err ...error) NestedError { func JoinE(message string, err ...error) Error {
b := NewBuilder("%s", message) b := NewBuilder("%s", message)
for _, e := range err { for _, e := range err {
b.AddE(e) b.AddE(e)
@ -92,13 +92,13 @@ func IsNotNil(err error) bool {
return err != nil return err != nil
} }
func (ne NestedError) String() string { func (ne Error) String() string {
var buf strings.Builder var buf strings.Builder
ne.writeToSB(&buf, 0, "") ne.writeToSB(&buf, 0, "")
return buf.String() return buf.String()
} }
func (ne NestedError) Is(err error) bool { func (ne Error) Is(err error) bool {
if ne == nil { if ne == nil {
return err == nil return err == nil
} }
@ -114,18 +114,18 @@ func (ne NestedError) Is(err error) bool {
return false return false
} }
func (ne NestedError) IsNot(err error) bool { func (ne Error) IsNot(err error) bool {
return !ne.Is(err) return !ne.Is(err)
} }
func (ne NestedError) Error() error { func (ne Error) Error() error {
if ne == nil { if ne == nil {
return nil return nil
} }
return ne.buildError(0, "") return ne.buildError(0, "")
} }
func (ne NestedError) With(s any) NestedError { func (ne Error) With(s any) Error {
if ne == nil { if ne == nil {
return ne return ne
} }
@ -133,7 +133,7 @@ func (ne NestedError) With(s any) NestedError {
switch ss := s.(type) { switch ss := s.(type) {
case nil: case nil:
return ne return ne
case *NestedErrorImpl: case *ErrorImpl:
if len(ss.extras) == 1 { if len(ss.extras) == 1 {
ne.extras = append(ne.extras, ss.extras[0]) ne.extras = append(ne.extras, ss.extras[0])
return ne return ne
@ -151,11 +151,11 @@ func (ne NestedError) With(s any) NestedError {
return ne.withError(From(errors.New(msg))) return ne.withError(From(errors.New(msg)))
} }
func (ne NestedError) Extraf(format string, args ...any) NestedError { func (ne Error) Extraf(format string, args ...any) Error {
return ne.With(errorf(format, args...)) return ne.With(errorf(format, args...))
} }
func (ne NestedError) Subject(s any, sep ...string) NestedError { func (ne Error) Subject(s any, sep ...string) Error {
if ne == nil { if ne == nil {
return ne return ne
} }
@ -179,26 +179,26 @@ func (ne NestedError) Subject(s any, sep ...string) NestedError {
return ne return ne
} }
func (ne NestedError) Subjectf(format string, args ...any) NestedError { func (ne Error) Subjectf(format string, args ...any) Error {
if ne == nil { if ne == nil {
return ne return ne
} }
return ne.Subject(fmt.Sprintf(format, args...)) return ne.Subject(fmt.Sprintf(format, args...))
} }
func (ne NestedError) JSONObject() JSONNestedError { func (ne Error) JSONObject() ErrorJSONMarshaller {
extras := make([]JSONNestedError, len(ne.extras)) extras := make([]ErrorJSONMarshaller, len(ne.extras))
for i, e := range ne.extras { for i, e := range ne.extras {
extras[i] = e.JSONObject() extras[i] = e.JSONObject()
} }
return JSONNestedError{ return ErrorJSONMarshaller{
Subject: ne.subject, Subject: ne.subject,
Err: ne.err.Error(), Err: ne.err.Error(),
Extras: extras, Extras: extras,
} }
} }
func (ne NestedError) JSON() []byte { func (ne Error) JSON() []byte {
b, err := json.MarshalIndent(ne.JSONObject(), "", " ") b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
if err != nil { if err != nil {
panic(err) panic(err)
@ -206,19 +206,19 @@ func (ne NestedError) JSON() []byte {
return b return b
} }
func (ne NestedError) NoError() bool { func (ne Error) NoError() bool {
return ne == nil return ne == nil
} }
func (ne NestedError) HasError() bool { func (ne Error) HasError() bool {
return ne != nil return ne != nil
} }
func errorf(format string, args ...any) NestedError { func errorf(format string, args ...any) Error {
return From(fmt.Errorf(format, args...)) return From(fmt.Errorf(format, args...))
} }
func fromJSONObject(obj JSONNestedError) (NestedError, bool) { func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
data, err := json.Marshal(obj) data, err := json.Marshal(obj)
if err != nil { if err != nil {
return nil, false return nil, false
@ -226,14 +226,14 @@ func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
return FromJSON(data) return FromJSON(data)
} }
func (ne NestedError) withError(err NestedError) NestedError { func (ne Error) withError(err Error) Error {
if ne != nil && err != nil { if ne != nil && err != nil {
ne.extras = append(ne.extras, *err) ne.extras = append(ne.extras, *err)
} }
return ne return ne
} }
func (ne NestedError) appendMsg(msg string) NestedError { func (ne Error) appendMsg(msg string) Error {
if ne == nil { if ne == nil {
return nil return nil
} }
@ -241,7 +241,7 @@ func (ne NestedError) appendMsg(msg string) NestedError {
return ne return ne
} }
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
for range level { for range level {
sb.WriteString(" ") sb.WriteString(" ")
} }
@ -266,7 +266,7 @@ func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
} }
} }
func (ne NestedError) buildError(level int, prefix string) error { func (ne Error) buildError(level int, prefix string) error {
var res error var res error
var sb strings.Builder var sb strings.Builder

View file

@ -26,7 +26,7 @@ func TestErrorIs(t *testing.T) {
} }
func TestErrorNestedIs(t *testing.T) { func TestErrorNestedIs(t *testing.T) {
var err NestedError var err Error
ExpectTrue(t, err.Is(nil)) ExpectTrue(t, err.Is(nil))
err = Failure("some reason") err = Failure("some reason")
@ -40,7 +40,7 @@ func TestErrorNestedIs(t *testing.T) {
} }
func TestIsNil(t *testing.T) { func TestIsNil(t *testing.T) {
var err NestedError var err Error
ExpectTrue(t, err.Is(nil)) ExpectTrue(t, err.Is(nil))
ExpectTrue(t, err == nil) ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError()) ExpectTrue(t, err.NoError())

View file

@ -22,62 +22,62 @@ var (
const fmtSubjectWhat = "%w %v: %q" const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) NestedError { func Failure(what string) Error {
return errorf("%s %w", what, ErrFailure) return errorf("%s %w", what, ErrFailure)
} }
func FailedWhy(what string, why string) NestedError { func FailedWhy(what string, why string) Error {
return Failure(what).With(why) return Failure(what).With(why)
} }
func FailWith(what string, err any) NestedError { func FailWith(what string, err any) Error {
return Failure(what).With(err) return Failure(what).With(err)
} }
func Invalid(subject, what any) NestedError { func Invalid(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what) return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
} }
func Unsupported(subject, what any) NestedError { func Unsupported(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what) return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
} }
func Unexpected(subject, what any) NestedError { func Unexpected(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what) return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
} }
func UnexpectedError(err error) NestedError { func UnexpectedError(err error) Error {
return errorf("%w error: %w", ErrUnexpected, err) return errorf("%w error: %w", ErrUnexpected, err)
} }
func NotExist(subject, what any) NestedError { func NotExist(subject, what any) Error {
return errorf("%v %w: %v", subject, ErrNotExists, what) return errorf("%v %w: %v", subject, ErrNotExists, what)
} }
func Missing(subject any) NestedError { func Missing(subject any) Error {
return errorf("%w %v", ErrMissing, subject) return errorf("%w %v", ErrMissing, subject)
} }
func Duplicated(subject, what any) NestedError { func Duplicated(subject, what any) Error {
return errorf("%w %v: %v", ErrDuplicated, subject, what) return errorf("%w %v: %v", ErrDuplicated, subject, what)
} }
func OutOfRange(subject any, value any) NestedError { func OutOfRange(subject any, value any) Error {
return errorf("%v %w: %v", subject, ErrOutOfRange, value) return errorf("%v %w: %v", subject, ErrOutOfRange, value)
} }
func TypeError(subject any, from, to reflect.Type) NestedError { func TypeError(subject any, from, to reflect.Type) Error {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to) return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
} }
func TypeError2(subject any, from, to reflect.Value) NestedError { func TypeError2(subject any, from, to reflect.Value) Error {
return TypeError(subject, from.Type(), to.Type()) return TypeError(subject, from.Type(), to.Type())
} }
func TypeMismatch[Expect any](value any) NestedError { func TypeMismatch[Expect any](value any) Error {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
} }
func PanicRecv(format string, args ...any) NestedError { func PanicRecv(format string, args ...any) Error {
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...)) return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
} }

View file

@ -0,0 +1,15 @@
package loadbalancer
import "net/http"
type DummyResponseWriter struct{}
func (w *DummyResponseWriter) Header() (_ http.Header) {
return
}
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w *DummyResponseWriter) WriteHeader(int) {}

View file

@ -21,7 +21,7 @@ func (lb *LoadBalancer) newIPHash() impl {
if len(lb.Options) == 0 { if len(lb.Options) == 0 {
return impl return impl
} }
var err E.NestedError var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options) impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil { if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)

View file

@ -1,10 +1,13 @@
package loadbalancer package loadbalancer
import ( import (
"context"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -54,10 +57,10 @@ func New(cfg *Config) *LoadBalancer {
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError { func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
lb.startTime = time.Now() lb.startTime = time.Now()
lb.task = routeSubtask lb.task = routeSubtask
lb.task.OnComplete("loadbalancer cleanup", func() { lb.task.OnFinished("loadbalancer cleanup", func() {
if lb.impl != nil { if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) { lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v) lb.impl.OnRemoveServer(v)
@ -69,7 +72,7 @@ func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
} }
// Finish implements task.TaskFinisher. // Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason string) { func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason) lb.task.Finish(reason)
} }
@ -128,7 +131,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.rebalance() lb.rebalance()
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
} }
func (lb *LoadBalancer) RemoveServer(srv *Server) { func (lb *LoadBalancer) RemoveServer(srv *Server) {
@ -147,11 +150,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
if lb.pool.Size() == 0 { if lb.pool.Size() == 0 {
lb.task.Finish("no server left") lb.task.Finish("no server left")
logger.Infof("[remove] loadbalancer %s stopped", lb.Link) logger.Infof("loadbalancer %s stopped", lb.Link)
return return
} }
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
} }
func (lb *LoadBalancer) rebalance() { func (lb *LoadBalancer) rebalance() {
@ -211,6 +214,21 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return return
} }
if r.Header.Get(common.HeaderCheckRedirect) != "" {
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
defer cancel()
// send dummy request to wake all servers
var dummyRW *DummyResponseWriter
for _, srv := range srvs {
// wake only if server implements Waker
_, ok := srv.handler.(idlewatcher.Waker)
if !ok {
continue
}
wakeReq := r.Clone(ctx)
srv.ServeHTTP(dummyRW, wakeReq)
}
}
lb.impl.ServeHTTP(srvs, rw, r) lb.impl.ServeHTTP(srvs, rw, r)
} }
@ -261,10 +279,9 @@ func (lb *LoadBalancer) String() string {
func (lb *LoadBalancer) availServers() []*Server { func (lb *LoadBalancer) availServers() []*Server {
avail := make([]*Server, 0, lb.pool.Size()) avail := make([]*Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv *Server) { lb.pool.RangeAll(func(_ string, srv *Server) {
if srv.Status().Bad() { if srv.Status().Good() {
return avail = append(avail, srv)
} }
avail = append(avail, srv)
}) })
return avail return avail
} }

View file

@ -14,8 +14,8 @@ func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {} func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) { func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) index := lb.index.Add(1) % uint32(len(srvs))
srvs[index%uint32(len(srvs))].ServeHTTP(rw, r) srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) { if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0) lb.index.Store(0)
} }

View file

@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
} }
} }
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) { func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist) wl := new(cidrWhitelist)
wl.m = &Middleware{ wl.m = &Middleware{
impl: wl, impl: wl,

View file

@ -33,7 +33,7 @@ var CloudflareRealIP = &realIP{
m: &Middleware{withOptions: NewCloudflareRealIP}, m: &Middleware{withOptions: NewCloudflareRealIP},
} }
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) { func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP) cri := new(realIP)
cri.m = &Middleware{ cri.m = &Middleware{
impl: cri, impl: cri,

View file

@ -36,7 +36,7 @@ var ForwardAuth = &forwardAuth{
m: &Middleware{withOptions: NewForwardAuthfunc}, m: &Middleware{withOptions: NewForwardAuthfunc},
} }
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth) fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts) fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts) err := Deserialize(optsRaw, fa.forwardAuthOpts)

View file

@ -11,7 +11,7 @@ import (
) )
type ( type (
Error = E.NestedError Error = E.Error
ReverseProxy = gphttp.ReverseProxy ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest ProxyRequest = gphttp.ProxyRequest
@ -24,7 +24,7 @@ type (
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request) BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request) RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError) CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any OptionsRaw = map[string]any
Options any Options any
@ -77,7 +77,7 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ") }, "", " ")
} }
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) { func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if len(optsRaw) != 0 && m.withOptions != nil { if len(optsRaw) != 0 && m.withOptions != nil {
return m.withOptions(optsRaw) return m.withOptions(optsRaw)
} }
@ -108,7 +108,7 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
} }
// TODO: check conflict or duplicates. // TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) { func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
middlewares = make([]*Middleware, 0, len(middlewaresMap)) middlewares = make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares") invalidM := E.NewBuilder("invalid middlewares")
@ -136,7 +136,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
return return
} }
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) { func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap) middlewares, err = createMiddlewares(middlewaresMap)
if err != nil { if err != nil {

View file

@ -10,7 +10,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) { func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
fileContent, err := os.ReadFile(filePath) fileContent, err := os.ReadFile(filePath)
if err != nil { if err != nil {
return nil, E.FailWith("read middleware compose file", err) return nil, E.FailWith("read middleware compose file", err)
@ -18,7 +18,7 @@ func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E
return BuildMiddlewaresFromYAML(fileContent) return BuildMiddlewaresFromYAML(fileContent)
} }
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) { func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
b := E.NewBuilder("middlewares compile errors") b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr) defer b.To(&outErr)

View file

@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{
m: &Middleware{withOptions: NewModifyRequest}, m: &Middleware{withOptions: NewModifyRequest},
} }
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) { func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest) mr := new(modifyRequest)
var mrFunc RewriteFunc var mrFunc RewriteFunc
if common.IsDebug { if common.IsDebug {

View file

@ -24,7 +24,7 @@ var ModifyResponse = &modifyResponse{
m: &Middleware{withOptions: NewModifyResponse}, m: &Middleware{withOptions: NewModifyResponse},
} }
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) { func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse) mr := new(modifyResponse)
mr.m = &Middleware{impl: mr} mr.m = &Middleware{impl: mr}
if common.IsDebug { if common.IsDebug {

View file

@ -26,7 +26,7 @@ var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2}, m: &Middleware{withOptions: NewAuthentikOAuth2},
} }
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) { func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
oauth := new(oAuth2) oauth := new(oAuth2)
oauth.m = &Middleware{ oauth.m = &Middleware{
impl: oauth, impl: oauth,

View file

@ -41,7 +41,7 @@ var realIPOptsDefault = func() *realIPOpts {
} }
} }
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) { func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
riWithOpts := new(realIP) riWithOpts := new(realIP)
riWithOpts.m = &Middleware{ riWithOpts.m = &Middleware{
impl: riWithOpts, impl: riWithOpts,

View file

@ -72,7 +72,7 @@ type testArgs struct {
scheme string scheme string
} }
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) { func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
var body io.Reader var body io.Reader
var rr requestRecorder var rr requestRecorder
var proxyURL *url.URL var proxyURL *url.URL

View file

@ -9,7 +9,7 @@ import (
type CIDR net.IPNet type CIDR net.IPNet
func (cidr *CIDR) ConvertFrom(val any) E.NestedError { func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string) cidrStr, ok := val.(string)
if !ok { if !ok {
return E.TypeMismatch[string](val) return E.TypeMismatch[string](val)

View file

@ -7,13 +7,7 @@ import (
type Stream interface { type Stream interface {
fmt.Stringer fmt.Stringer
net.Listener
Setup() error Setup() error
Accept() (conn StreamConn, err error) Handle(conn net.Conn) error
Handle(conn StreamConn) error
CloseListeners()
}
type StreamConn interface {
RemoteAddr() net.Addr
Close() error
} }

View file

@ -1,7 +1,7 @@
package entry package entry
import ( import (
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
@ -18,7 +18,7 @@ type Entry interface {
IdlewatcherConfig() *idlewatcher.Config IdlewatcherConfig() *idlewatcher.Config
} }
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) { func ValidateEntry(m *RawEntry) (Entry, E.Error) {
m.FillMissingFields() m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme) scheme, err := T.NewScheme(m.Scheme)

View file

@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"

View file

@ -7,7 +7,7 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.Error) {
h := make(http.Header) h := make(http.Header)
for k, v := range headers { for k, v := range headers {
vSplit := strings.Split(v, ",") vSplit := strings.Split(v, ",")

View file

@ -9,6 +9,6 @@ type (
Subdomain = Alias Subdomain = Alias
) )
func ValidateHost[String ~string](s String) (Host, E.NestedError) { func ValidateHost[String ~string](s String) (Host, E.Error) {
return Host(s), nil return Host(s), nil
} }

View file

@ -13,7 +13,7 @@ type (
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func ValidatePathPattern(s string) (PathPattern, E.NestedError) { func ValidatePathPattern(s string) (PathPattern, E.Error) {
if len(s) == 0 { if len(s) == 0 {
return "", E.Invalid("path", "must not be empty") return "", E.Invalid("path", "must not be empty")
} }
@ -23,7 +23,7 @@ func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
return PathPattern(s), nil return PathPattern(s), nil
} }
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) { func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
if len(s) == 0 { if len(s) == 0 {
return []PathPattern{"/"}, nil return []PathPattern{"/"}, nil
} }

View file

@ -8,7 +8,7 @@ import (
type Port int type Port int
func ValidatePort[String ~string](v String) (Port, E.NestedError) { func ValidatePort[String ~string](v String) (Port, E.Error) {
p, err := strconv.Atoi(string(v)) p, err := strconv.Atoi(string(v))
if err != nil { if err != nil {
return ErrPort, E.Invalid("port number", v).With(err) return ErrPort, E.Invalid("port number", v).With(err)
@ -16,7 +16,7 @@ func ValidatePort[String ~string](v String) (Port, E.NestedError) {
return ValidatePortInt(p) return ValidatePortInt(p)
} }
func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) { func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) {
p := Port(v) p := Port(v)
if !p.inBound() { if !p.inBound() {
return ErrPort, E.OutOfRange("port", p) return ErrPort, E.OutOfRange("port", p)

View file

@ -6,7 +6,7 @@ import (
type Scheme string type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.NestedError) { func NewScheme[String ~string](s String) (Scheme, E.Error) {
switch s { switch s {
case "http", "https", "tcp", "udp": case "http", "https", "tcp", "udp":
return Scheme(s), nil return Scheme(s), nil

View file

@ -12,7 +12,7 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"` ProxyPort Port `json:"proxy"`
} }
func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) { func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
split := strings.Split(p, ":") split := strings.Split(p, ":")
switch len(split) { switch len(split) {
@ -47,7 +47,7 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) {
return StreamPort{listeningPort, proxyPort}, nil return StreamPort{listeningPort, proxyPort}, nil
} }
func parseNameToPort(name string) (Port, E.NestedError) { func parseNameToPort(name string) (Port, E.Error) {
port, ok := common.ServiceNamePortMapTCP[name] port, ok := common.ServiceNamePortMapTCP[name]
if !ok { if !ok {
return ErrPort, E.Invalid("service", name) return ErrPort, E.Invalid("service", name)

View file

@ -12,7 +12,7 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"` ProxyScheme Scheme `json:"proxy"`
} }
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) {
ss = &StreamScheme{} ss = &StreamScheme{}
parts := strings.Split(s, ":") parts := strings.Split(s, ":")
if len(parts) == 1 { if len(parts) == 1 {

View file

@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
} }
} }
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) { func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
var trans *http.Transport var trans *http.Transport
if entry.NoTLSVerify { if entry.NoTLSVerify {
@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string {
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError { func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) { if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve") providerSubtask.Finish("should not serve")
return nil return nil
@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
r.addToLoadBalancer() r.addToLoadBalancer()
} else { } else {
httpRoutes.Store(string(r.Alias), r) httpRoutes.Store(string(r.Alias), r)
r.task.OnComplete("stop rp", func() { r.task.OnFinished("remove from route table", func() {
httpRoutes.Delete(string(r.Alias)) httpRoutes.Delete(string(r.Alias))
}) })
} }
@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
} }
// Finish implements task.TaskFinisher. // Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason string) { func (r *HTTPRoute) Finish(reason any) {
r.task.Finish(reason) r.task.Finish(reason)
} }
@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() {
} }
} else { } else {
lb = loadbalancer.New(r.LoadBalance) lb = loadbalancer.New(r.LoadBalance)
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link) lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link)
lbTask.OnComplete("remove lb from routes", func() { lbTask.OnCancel("remove lb from routes", func() {
httpRoutes.Delete(r.LoadBalance.Link) httpRoutes.Delete(r.LoadBalance.Link)
}) })
lb.Start(lbTask) lb.Start(lbTask)
@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
httpRoutes.Store(r.LoadBalance.Link, linked) httpRoutes.Store(r.LoadBalance.Link, linked)
} }
r.loadBalancer = lb r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server) lb.AddServer(r.server)
r.task.OnComplete("remove server from lb", func() { r.task.OnCancel("remove server from lb", func() {
lb.RemoveServer(r.server) lb.RemoveServer(r.server)
}) })
} }

View file

@ -25,7 +25,7 @@ var (
AliasRefRegexOld = regexp.MustCompile(`\$\d+`) AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
) )
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) { func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) {
if dockerHost == common.DockerHostFromEnv { if dockerHost == common.DockerHostFromEnv {
dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost) dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost)
} }
@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost) return W.NewDockerWatcher(p.dockerHost)
} }
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) {
routes = R.NewRoutes() routes = R.NewRoutes()
entries := entry.NewProxyEntries() entries := entry.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true) containers, err := D.ListContainers(p.dockerHost)
if err != nil { if err != nil {
return routes, E.FailWith("connect to docker", err) return routes, err
} }
errors := E.NewBuilder("errors in docker labels") errors := E.NewBuilder("errors in docker labels")
for _, c := range info.Containers { for _, c := range containers {
container := D.FromDocker(&c, p.dockerHost) container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded { if container.IsExcluded {
continue continue
@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
}) })
} }
entries.RangeAll(func(_ string, e *entry.RawEntry) {
e.Container.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries) routes, err = R.FromEntries(entries)
errors.Add(err) errors.Add(err)
@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
// Returns a list of proxy entries for a container. // Returns a list of proxy entries for a container.
// Always non-nil. // Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) { func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) {
entries = entry.NewProxyEntries() entries = entry.NewProxyEntries()
if p.shouldIgnore(container) { if p.shouldIgnore(container) {
@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
return entries, errors.Build().Subject(container.ContainerName) return entries, errors.Build().Subject(container.ContainerName)
} }
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) { func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) {
b := E.NewBuilder("errors in label %s", key) b := E.NewBuilder("errors in label %s", key)
defer b.To(&res) defer b.To(&res)

View file

@ -1,7 +1,9 @@
package provider package provider
import ( import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
return return
} }
oldRoutes.RangeAll(func(k string, v *route.Route) { if common.IsDebug {
if !newRoutes.Has(k) { eventsLog := E.NewBuilder("events")
handler.Remove(v) for _, event := range events {
eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID)
}
handler.provider.l.Debug(eventsLog.String())
oldRoutesLog := E.NewBuilder("old routes")
oldRoutes.RangeAll(func(k string, r *route.Route) {
oldRoutesLog.Addf(k)
})
handler.provider.l.Debug(oldRoutesLog.String())
newRoutesLog := E.NewBuilder("new routes")
newRoutes.RangeAll(func(k string, r *route.Route) {
newRoutesLog.Addf(k)
})
handler.provider.l.Debug(newRoutesLog.String())
}
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
newr, ok := newRoutes.Load(k)
if !ok {
handler.Remove(oldr)
} else if handler.matchAny(events, newr) {
handler.Update(parent, oldr, newr)
} else if entry.ShouldNotServe(newr) {
handler.Remove(oldr)
} }
}) })
newRoutes.RangeAll(func(k string, newr *route.Route) { newRoutes.RangeAll(func(k string, newr *route.Route) {
if oldRoutes.Has(k) { if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) {
for _, ev := range events {
if handler.match(ev, newr) {
old, ok := oldRoutes.Load(k)
if !ok { // should not happen
panic("race condition")
}
handler.Update(parent, old, newr)
return
}
}
} else {
handler.Add(parent, newr) handler.Add(parent, newr)
} }
}) })
} }
func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool {
for _, event := range events {
if handler.match(event, route) {
return true
}
}
return false
}
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.t { switch handler.provider.GetType() {
case ProviderTypeDocker: case ProviderTypeDocker:
return route.Entry.Container.ContainerID == event.ActorID || return route.Entry.Container.ContainerID == event.ActorID ||
route.Entry.Container.ContainerName == event.ActorName route.Entry.Container.ContainerName == event.ActorName
@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
func (handler *EventHandler) Add(parent task.Task, route *route.Route) { func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
err := handler.provider.startRoute(parent, route) err := handler.provider.startRoute(parent, route)
if err != nil { if err != nil {
handler.errs.Add(err) handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err))
} else { } else {
handler.added = append(handler.added, route.Entry.Alias) handler.added = append(handler.added, route.Entry.Alias)
} }
} }
func (handler *EventHandler) Remove(route *route.Route) { func (handler *EventHandler) Remove(route *route.Route) {
route.Finish("route removal") route.Finish("route removed")
handler.provider.routes.Delete(route.Entry.Alias)
handler.removed = append(handler.removed, route.Entry.Alias) handler.removed = append(handler.removed, route.Entry.Alias)
} }
@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
oldRoute.Finish("route update") oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute) err := handler.provider.startRoute(parent, newRoute)
if err != nil { if err != nil {
handler.errs.Add(err) handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err))
} else { } else {
handler.updated = append(handler.updated, newRoute.Entry.Alias) handler.updated = append(handler.updated, newRoute.Entry.Alias)
} }

View file

@ -18,7 +18,7 @@ type FileProvider struct {
path string path string
} }
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) { func FileProviderImpl(filename string) (ProviderImpl, E.Error) {
impl := &FileProvider{ impl := &FileProvider{
fileName: filename, fileName: filename,
path: path.Join(common.ConfigBasePath, filename), path: path.Join(common.ConfigBasePath, filename),
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
} }
} }
func Validate(data []byte) E.NestedError { func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data) return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
} }
@ -42,7 +42,7 @@ func (p FileProvider) String() string {
return p.fileName return p.fileName
} }
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) {
routes = R.NewRoutes() routes = R.NewRoutes()
b := E.NewBuilder("validation failure") b := E.NewBuilder("validation failure")

View file

@ -7,7 +7,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
@ -29,7 +28,7 @@ type (
ProviderImpl interface { ProviderImpl interface {
fmt.Stringer fmt.Stringer
NewWatcher() W.Watcher NewWatcher() W.Watcher
LoadRoutesImpl() (R.Routes, E.NestedError) LoadRoutesImpl() (R.Routes, E.Error)
} }
ProviderType string ProviderType string
ProviderStats struct { ProviderStats struct {
@ -43,7 +42,7 @@ const (
ProviderTypeDocker ProviderType = "docker" ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file" ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 500 * time.Millisecond providerEventFlushInterval = 300 * time.Millisecond
) )
func newProvider(name string, t ProviderType) *Provider { func newProvider(name string, t ProviderType) *Provider {
@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider {
return p return p
} }
func NewFileProvider(filename string) (p *Provider, err E.NestedError) { func NewFileProvider(filename string) (p *Provider, err E.Error) {
name := path.Base(filename) name := path.Base(filename)
if name == "" { if name == "" {
return nil, E.Invalid("file name", "empty") return nil, E.Invalid("file name", "empty")
@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
return return
} }
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) { func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) {
if name == "" { if name == "" {
return nil, E.Invalid("provider name", "empty") return nil, E.Invalid("provider name", "empty")
} }
@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil return []byte(p.String()), nil
} }
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError { func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
if entry.UseLoadBalance(r) { subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
r.Entry.Alias = p.String() + "/" + r.Entry.Alias
}
subtask := parent.Subtask(r.Entry.Alias)
err := r.Start(subtask) err := r.Start(subtask)
if err != nil { if err != nil {
p.routes.Delete(r.Entry.Alias) p.routes.Delete(r.Entry.Alias)
subtask.Finish(err.String()) // just to ensure subtask.Finish(err) // just to ensure
return err return err
} else { } else {
subtask.OnComplete("del from provider", func() { p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
p.routes.Delete(r.Entry.Alias) p.routes.Delete(r.Entry.Alias)
}) })
} }
@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) { func (p *Provider) Start(configSubtask task.Task) (res E.Error) {
errors := E.NewBuilder("errors starting routes") errors := E.NewBuilder("errors starting routes")
defer errors.To(&res) defer errors.To(&res)
@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
handler.Log() handler.Log()
flushTask.Finish("events flushed") flushTask.Finish("events flushed")
}, },
func(err E.NestedError) { func(err E.Error) {
p.l.Error(err) p.l.Error(err)
}, },
) )
@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
return p.routes.Load(alias) return p.routes.Load(alias)
} }
func (p *Provider) LoadRoutes() E.NestedError { func (p *Provider) LoadRoutes() E.Error {
var err E.NestedError var err E.Error
p.routes, err = p.LoadRoutesImpl() p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 { if p.routes.Size() > 0 {
return err return err

94
internal/route/raw.go Normal file
View file

@ -0,0 +1,94 @@
package route
import (
"errors"
"fmt"
"net"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
RawStream struct {
*StreamRoute
listener net.Listener
targetAddr net.Addr
}
)
const (
streamBufferSize = 8192
streamDialTimeout = 5 * time.Second
)
func NewRawStreamRoute(base *StreamRoute) *RawStream {
return &RawStream{
StreamRoute: base,
}
}
func (route *RawStream) Setup() error {
var lcfg net.ListenConfig
var err error
switch route.Scheme.ListeningScheme {
case "tcp":
route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
route.listener = tcpListener
case "udp":
route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port)
route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener)
default:
return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme))
}
return nil
}
func (route *RawStream) Accept() (net.Conn, error) {
if route.listener == nil {
return nil, errors.New("listener not yet set up")
}
return route.listener.Accept()
}
func (route *RawStream) Handle(c net.Conn) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
route.task.OnCancel("close conn", func() { clientConn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr)
if err != nil {
return err
}
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()
}
func (route *RawStream) Close() error {
return route.listener.Close()
}

View file

@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container {
return rt.Entry.Container return rt.Entry.Container
} }
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) { func NewRoute(raw *entry.RawEntry) (*Route, E.Error) {
en, err := entry.ValidateEntry(raw) en, err := entry.ValidateEntry(raw)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
}, nil }, nil
} }
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) { func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
b := E.NewBuilder("errors in routes") b := E.NewBuilder("errors in routes")
routes := NewRoutes() routes := NewRoutes()

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
stdNet "net"
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] {
return streamRoutes return streamRoutes
} }
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) { func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
// TODO: support non-coherent scheme // TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() { if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
}, nil }, nil
} }
func (r *StreamRoute) Finish(reason string) {
r.task.Finish(reason)
}
func (r *StreamRoute) String() string { func (r *StreamRoute) String() string {
return fmt.Sprintf("stream %s", r.Alias) return fmt.Sprintf("stream %s", r.Alias)
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError { func (r *StreamRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) { if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve") providerSubtask.Finish("should not serve")
return nil return nil
@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
r.HealthCheck.Disable = true r.HealthCheck.Disable = true
} }
if r.Scheme.ListeningScheme.IsTCP() { // if r.Scheme.ListeningScheme.IsTCP() {
r.Stream = NewTCPRoute(r) // r.Stream = NewTCPRoute(r)
} else { // } else {
r.Stream = NewUDPRoute(r) // r.Stream = NewUDPRoute(r)
} // }
r.task = providerSubtask
r.Stream = NewRawStreamRoute(r)
r.l = logrus.WithField("route", r.Stream.String()) r.l = logrus.WithField("route", r.Stream.String())
switch { switch {
@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil { if err != nil {
r.task.Finish(err)
return err return err
} }
r.Stream = waker r.Stream = waker
@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
case entry.UseHealthCheck(r): case entry.UseHealthCheck(r):
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck) r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
} }
r.task = providerSubtask
r.task.OnComplete("stop stream", r.CloseListeners)
if err := r.Setup(); err != nil { if err := r.Setup(); err != nil {
r.task.Finish(err)
return E.FailWith("setup", err) return E.FailWith("setup", err)
} }
r.l.Infof("listening on port %d", r.Port.ListeningPort)
go r.acceptConnections() r.task.OnFinished("close stream", func() {
if err := r.Close(); err != nil {
r.l.Error("close stream error: ", err)
}
})
r.task.OnFinished("remove from route table", func() {
streamRoutes.Delete(string(r.Alias))
})
r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort)
if r.HealthMon != nil { if r.HealthMon != nil {
r.HealthMon.Start(r.task.Subtask("health monitor")) if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn("health monitor error: ", err)
}
} }
go r.acceptConnections()
streamRoutes.Store(string(r.Alias), r) streamRoutes.Store(string(r.Alias), r)
return nil return nil
} }
func (r *StreamRoute) Finish(reason any) {
r.task.Finish(reason)
}
func (r *StreamRoute) acceptConnections() { func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed")
for { for {
select { select {
case <-r.task.Context().Done(): case <-r.task.Context().Done():
@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() {
if err != nil { if err != nil {
select { select {
case <-r.task.Context().Done(): case <-r.task.Context().Done():
return
default: default:
var nErr *stdNet.OpError r.l.Error("accept connection error: ", err)
ok := errors.As(err, &nErr) r.task.Finish(err)
if !(ok && nErr.Timeout()) {
r.l.Error("accept connection error: ", err)
r.task.Finish(err.Error())
return
}
continue
} }
return
} }
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String()) connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr()))
go func() { go func() {
err := r.Handle(conn) err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) { if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err) r.l.Error(err)
connTask.Finish(err.Error())
} else { } else {
connTask.Finish("connection closed") connTask.Finish("connection closed")
} }

View file

@ -1,71 +1,68 @@
package route package route
import ( // import (
"context" // "context"
"fmt" // "fmt"
"net" // "net"
"time" // "time"
"github.com/yusing/go-proxy/internal/net/types" // "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields" // T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" // U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" // F "github.com/yusing/go-proxy/internal/utils/functional"
) // )
const tcpDialTimeout = 5 * time.Second // const tcpDialTimeout = 5 * time.Second
type ( // type (
TCPConnMap = F.Map[net.Conn, struct{}] // TCPConnMap = F.Map[net.Conn, struct{}]
TCPRoute struct { // TCPRoute struct {
*StreamRoute // *StreamRoute
listener *net.TCPListener // listener *net.TCPListener
} // }
) // )
func NewTCPRoute(base *StreamRoute) *TCPRoute { // func NewTCPRoute(base *StreamRoute) *TCPRoute {
return &TCPRoute{StreamRoute: base} // return &TCPRoute{StreamRoute: base}
} // }
func (route *TCPRoute) Setup() error { // func (route *TCPRoute) Setup() error {
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) // var cfg net.ListenConfig
if err != nil { // in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
return err // if err != nil {
} // return err
//! this read the allocated port from original ':0' // }
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) // //! this read the allocated port from original ':0'
route.listener = in.(*net.TCPListener) // route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
return nil // route.listener = in.(*net.TCPListener)
} // return nil
// }
func (route *TCPRoute) Accept() (types.StreamConn, error) { // func (route *TCPRoute) Accept() (types.StreamConn, error) {
route.listener.SetDeadline(time.Now().Add(time.Second)) // return route.listener.Accept()
return route.listener.Accept() // }
}
func (route *TCPRoute) Handle(c types.StreamConn) error { // func (route *TCPRoute) Handle(c types.StreamConn) error {
clientConn := c.(net.Conn) // clientConn := c.(net.Conn)
defer clientConn.Close() // defer clientConn.Close()
route.task.OnComplete("close conn", func() { clientConn.Close() }) // route.task.OnCancel("close conn", func() { clientConn.Close() })
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) // ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) // serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{} // dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) // serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
cancel() // cancel()
if err != nil { // if err != nil {
return err // return err
} // }
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) // pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start() // return pipe.Start()
} // }
func (route *TCPRoute) CloseListeners() { // func (route *TCPRoute) Close() error {
if route.listener == nil { // return route.listener.Close()
return // }
}
route.listener.Close()
}

View file

@ -1,145 +1,149 @@
package route package route
import ( // import (
"errors" // "errors"
"fmt" // "fmt"
"io" // "io"
"net" // "net"
"time"
"github.com/yusing/go-proxy/internal/net/types" // "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields" // T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" // U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" // F "github.com/yusing/go-proxy/internal/utils/functional"
) // )
type ( // type (
UDPRoute struct { // UDPRoute struct {
*StreamRoute // *StreamRoute
connMap UDPConnMap // connMap UDPConnMap
listeningConn *net.UDPConn // listeningConn net.PacketConn
targetAddr *net.UDPAddr // targetAddr *net.UDPAddr
} // }
UDPConn struct { // UDPConn struct {
key string // key string
src *net.UDPConn // src net.Conn
dst *net.UDPConn // dst net.Conn
U.BidirectionalPipe // U.BidirectionalPipe
} // }
UDPConnMap = F.Map[string, *UDPConn] // UDPConnMap = F.Map[string, *UDPConn]
) // )
var NewUDPConnMap = F.NewMap[UDPConnMap] // var NewUDPConnMap = F.NewMap[UDPConnMap]
const udpBufferSize = 8192 // const udpBufferSize = 8192
func NewUDPRoute(base *StreamRoute) *UDPRoute { // func NewUDPRoute(base *StreamRoute) *UDPRoute {
return &UDPRoute{ // return &UDPRoute{
StreamRoute: base, // StreamRoute: base,
connMap: NewUDPConnMap(), // connMap: NewUDPConnMap(),
} // }
} // }
func (route *UDPRoute) Setup() error { // func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort)) // var cfg net.ListenConfig
if err != nil { // source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
return err // if err != nil {
} // return err
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr) // }
if err != nil { // raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
return err // if err != nil {
} // source.Close()
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) // return err
if err != nil { // }
source.Close()
return err
}
//! this read the allocated listeningPort from original ':0' // //! this read the allocated listeningPort from original ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) // route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source // route.listeningConn = source
route.targetAddr = raddr // route.targetAddr = raddr
return nil // return nil
} // }
func (route *UDPRoute) Accept() (types.StreamConn, error) { // func (route *UDPRoute) Accept() (types.StreamConn, error) {
in := route.listeningConn // in := route.listeningConn
buffer := make([]byte, udpBufferSize) // buffer := make([]byte, udpBufferSize)
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second)) // nRead, srcAddr, err := in.ReadFrom(buffer)
nRead, srcAddr, err := in.ReadFromUDP(buffer) // if err != nil {
if err != nil { // return nil, err
return nil, err // }
}
if nRead == 0 { // if nRead == 0 {
return nil, io.ErrShortBuffer // return nil, io.ErrShortBuffer
} // }
key := srcAddr.String() // key := srcAddr.String()
conn, ok := route.connMap.Load(key) // conn, ok := route.connMap.Load(key)
if !ok { // if !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr) // srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String())
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
dstConn, err := net.DialUDP("udp", nil, route.targetAddr) // dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String())
if err != nil { // if err != nil {
srcConn.Close() // srcConn.Close()
return nil, err // return nil, err
} // }
conn = &UDPConn{ // conn = &UDPConn{
key, // key,
srcConn, // srcConn,
dstConn, // dstConn,
U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), // U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
} // }
route.connMap.Store(key, conn) // route.connMap.Store(key, conn)
} // }
_, err = conn.dst.Write(buffer[:nRead]) // _, err = conn.dst.Write(buffer[:nRead])
return conn, err // return conn, err
} // }
func (route *UDPRoute) Handle(c types.StreamConn) error { // func (route *UDPRoute) Handle(c types.StreamConn) error {
conn := c.(*UDPConn) // switch c := c.(type) {
err := conn.Start() // case *UDPConn:
route.connMap.Delete(conn.key) // err := c.Start()
return err // route.connMap.Delete(c.key)
} // c.Close()
// return err
// case *net.TCPConn:
// in := route.listeningConn
// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr))
// if err != nil {
// return err
// }
// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start()
// c.Close()
// return err
// }
// return fmt.Errorf("unknown conn type: %T", c)
// }
func (route *UDPRoute) CloseListeners() { // func (route *UDPRoute) Close() error {
if route.listeningConn != nil { // route.connMap.RangeAllParallel(func(k string, v *UDPConn) {
route.listeningConn.Close() // v.Close()
} // })
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) { // route.connMap.Clear()
if err := conn.Close(); err != nil { // return route.listeningConn.Close()
route.l.Errorf("error closing conn: %s", err) // }
}
})
route.connMap.Clear()
}
// Close implements types.StreamConn // // Close implements types.StreamConn
func (conn *UDPConn) Close() error { // func (conn *UDPConn) Close() error {
return errors.Join(conn.src.Close(), conn.dst.Close()) // return errors.Join(conn.src.Close(), conn.dst.Close())
} // }
// RemoteAddr implements types.StreamConn // // RemoteAddr implements types.StreamConn
func (conn *UDPConn) RemoteAddr() net.Addr { // func (conn *UDPConn) RemoteAddr() net.Addr {
return conn.src.RemoteAddr() // return conn.src.RemoteAddr()
} // }
type sourceRWCloser struct { // type sourceRWCloser struct {
server *net.UDPConn // server net.PacketConn
*net.UDPConn // net.Conn
} // }
func (w sourceRWCloser) Write(p []byte) (int, error) { // func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp // return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr))
} // }

View file

@ -0,0 +1,73 @@
package route
import (
"context"
"io"
"net"
"sync"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPListener struct {
ctx context.Context
listener net.PacketConn
connMap UDPConnMap
mu sync.Mutex
}
UDPConnMap = F.Map[string, net.Conn]
)
var NewUDPConnMap = F.NewMap[UDPConnMap]
func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener {
return &UDPListener{
ctx: ctx,
listener: listener,
connMap: NewUDPConnMap(),
}
}
// Addr implements net.Listener.
func (route *UDPListener) Addr() net.Addr {
return route.listener.LocalAddr()
}
func (udpl *UDPListener) Accept() (net.Conn, error) {
in := udpl.listener
buffer := make([]byte, streamBufferSize)
nRead, srcAddr, err := in.ReadFrom(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
udpl.mu.Lock()
defer udpl.mu.Unlock()
key := srcAddr.String()
conn, ok := udpl.connMap.Load(key)
if !ok {
dialer := &net.Dialer{Timeout: streamDialTimeout}
srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String())
if err != nil {
return nil, err
}
udpl.connMap.Store(key, srcConn)
}
return conn, nil
}
// Close implements net.Listener.
func (route *UDPListener) Close() error {
route.connMap.RangeAllParallel(func(key string, conn net.Conn) {
conn.Close()
})
route.connMap.Clear()
return route.listener.Close()
}

View file

@ -116,7 +116,7 @@ func (s *Server) Start() {
}() }()
} }
s.task.OnComplete("stop server", s.stop) s.task.OnFinished("stop server", s.stop)
} }
func (s *Server) stop() { func (s *Server) stop() {

View file

@ -39,10 +39,11 @@ type (
// Use Task.Finish to stop all subtasks of the task. // Use Task.Finish to stop all subtasks of the task.
Task interface { Task interface {
TaskFinisher TaskFinisher
fmt.Stringer
// Name returns the name of the task. // Name returns the name of the task.
Name() string Name() string
// Context returns the context associated with the task. This context is // Context returns the context associated with the task. This context is
// canceled when Finish is called. // canceled when Finish of the task is called, or parent task is canceled.
Context() context.Context Context() context.Context
// FinishCause returns the reason / error that caused the task to be finished. // FinishCause returns the reason / error that caused the task to be finished.
FinishCause() error FinishCause() error
@ -53,12 +54,16 @@ type (
// If the parent's context is already canceled, the returned subtask will be canceled immediately. // If the parent's context is already canceled, the returned subtask will be canceled immediately.
// //
// This should not be called after Finish, Wait, or WaitSubTasks is called. // This should not be called after Finish, Wait, or WaitSubTasks is called.
Subtask(usageFmt string, args ...any) Task Subtask(name string) Task
// OnComplete calls fn when the task and all subtasks are finished. // OnFinished calls fn when all subtasks are finished.
// //
// It cannot be called after Finish or Wait is called. // It cannot be called after Finish or Wait is called.
OnComplete(about string, fn func()) OnFinished(about string, fn func())
// Wait waits for all subtasks, itself and all OnComplete to finish. // OnCancel calls fn when the task is canceled.
//
// It cannot be called after Finish or Wait is called.
OnCancel(about string, fn func())
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
// //
// It must be called only after Finish is called. // It must be called only after Finish is called.
Wait() Wait()
@ -76,37 +81,46 @@ type (
// The task passed must be a subtask of the caller task. // The task passed must be a subtask of the caller task.
// //
// callerSubtask.Finish must be called when start fails or the object is finished. // callerSubtask.Finish must be called when start fails or the object is finished.
Start(callerSubtask Task) E.NestedError Start(callerSubtask Task) E.Error
} }
TaskFinisher interface { TaskFinisher interface {
// Finish marks the task as finished by cancelling its context. // Finish marks the task as finished and cancel its context.
// //
// Then call Wait to wait for all subtasks and OnComplete of the task to finish. // Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
// //
// Note that it will also cancel all subtasks. // Note that it will also cancel all subtasks.
Finish(reason string) Finish(reason any)
} }
task struct { task struct {
ctx context.Context ctx context.Context
cancel context.CancelCauseFunc cancel context.CancelCauseFunc
parent *task parent *task
subtasks *xsync.MapOf[*task, struct{}] subtasks *xsync.MapOf[*task, struct{}]
subTasksWg sync.WaitGroup
name, line string name, line string
subTasksWg, onCompleteWg sync.WaitGroup OnFinishedFuncs []func()
OnFinishedMu sync.Mutex
onFinishedWg sync.WaitGroup
finishOnce sync.Once
} }
) )
var ( var (
ErrProgramExiting = errors.New("program exiting") ErrProgramExiting = errors.New("program exiting")
ErrTaskCancelled = errors.New("task cancelled") ErrTaskCanceled = errors.New("task canceled")
) )
// GlobalTask returns a new Task with the given name, derived from the global context. // GlobalTask returns a new Task with the given name, derived from the global context.
func GlobalTask(format string, args ...any) Task { func GlobalTask(format string, args ...any) Task {
return globalTask.Subtask(format, args...) if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return globalTask.Subtask(format)
} }
// DebugTaskMap returns a map[string]any representation of the global task tree. // DebugTaskMap returns a map[string]any representation of the global task tree.
@ -155,6 +169,10 @@ func (t *task) Name() string {
return t.name return t.name
} }
func (t *task) String() string {
return t.name
}
func (t *task) Context() context.Context { func (t *task) Context() context.Context {
return t.ctx return t.ctx
} }
@ -171,43 +189,83 @@ func (t *task) Parent() Task {
return t.parent return t.parent
} }
func (t *task) OnComplete(about string, fn func()) { func (t *task) runAllOnFinished(onCompTask Task) {
t.onCompleteWg.Add(1) <-t.ctx.Done()
t.WaitSubTasks()
for _, OnFinishedFunc := range t.OnFinishedFuncs {
OnFinishedFunc()
t.onFinishedWg.Done()
}
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
}
func (t *task) OnFinished(about string, fn func()) {
if t.parent == globalTask {
t.OnCancel(about, fn)
return
}
t.onFinishedWg.Add(1)
t.OnFinishedMu.Lock()
defer t.OnFinishedMu.Unlock()
if t.OnFinishedFuncs == nil {
onCompTask := GlobalTask(t.name + " > OnFinished")
go t.runAllOnFinished(onCompTask)
}
var file string var file string
var line int var line int
if common.IsTrace { if common.IsTrace {
_, file, line, _ = runtime.Caller(1) _, file, line, _ = runtime.Caller(1)
} }
go func() { idx := len(t.OnFinishedFuncs)
wrapped := func() {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err) logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err)
} }
}() }()
defer t.onCompleteWg.Done() fn()
t.subTasksWg.Wait() logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about)
}
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
}
func (t *task) OnCancel(about string, fn func()) {
onCompTask := GlobalTask(t.name + " > OnFinished")
go func() {
<-t.ctx.Done() <-t.ctx.Done()
fn() fn()
logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about) onCompTask.Finish("done")
t.cancel(nil) // ensure resources are released logrus.Tracef("%s > onCancel done: %s", t.name, about)
}() }()
} }
func (t *task) Finish(reason string) { func (t *task) Finish(reason any) {
t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason)) var format string
t.Wait() switch reason.(type) {
case error:
format = "%w"
case string, fmt.Stringer:
format = "%s"
default:
format = "%v"
}
t.finishOnce.Do(func() {
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
t.Wait()
})
} }
func (t *task) Subtask(format string, args ...any) Task { func (t *task) Subtask(name string) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
ctx, cancel := context.WithCancelCause(t.ctx) ctx, cancel := context.WithCancelCause(t.ctx)
return t.newSubTask(ctx, cancel, format) return t.newSubTask(ctx, cancel, name)
} }
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task { func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
parent := t parent := t
if common.IsTrace {
name = parent.name + " > " + name
}
subtask := &task{ subtask := &task{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -222,10 +280,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
if ok { if ok {
subtask.line = fmt.Sprintf("%s:%d", file, line) subtask.line = fmt.Sprintf("%s:%d", file, line)
} }
logrus.Tracef("line %s\ntask %q started", subtask.line, name) logrus.Tracef("line %s\n%s started", subtask.line, name)
go func() { go func() {
subtask.Wait() subtask.Wait()
logrus.Tracef("task %q finished", subtask.Name()) logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause())
}() }()
} }
go func() { go func() {
@ -237,11 +295,9 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
} }
func (t *task) Wait() { func (t *task) Wait() {
t.subTasksWg.Wait() <-t.ctx.Done()
if t != globalTask { t.WaitSubTasks()
<-t.ctx.Done() t.onFinishedWg.Wait()
}
t.onCompleteWg.Wait()
} }
func (t *task) WaitSubTasks() { func (t *task) WaitSubTasks() {
@ -270,9 +326,9 @@ func (t *task) tree(prefix ...string) string {
} }
if t.line != "" { if t.line != "" {
sb.WriteString("line " + t.line + "\n") sb.WriteString("line " + t.line + "\n")
} if len(pre) > 0 {
if len(pre) > 0 { sb.WriteString(pre + "- ")
sb.WriteString(pre + "- ") }
} }
sb.WriteString(t.Name() + "\n") sb.WriteString(t.Name() + "\n")
t.subtasks.Range(func(subtask *task, _ struct{}) bool { t.subtasks.Range(func(subtask *task, _ struct{}) bool {
@ -299,7 +355,8 @@ func (t *task) tree(prefix ...string) string {
// only. // only.
func (t *task) serialize() map[string]any { func (t *task) serialize() map[string]any {
m := make(map[string]any) m := make(map[string]any)
m["name"] = t.name parts := strings.Split(t.name, ">")
m["name"] = strings.TrimSpace(parts[len(parts)-1])
if t.line != "" { if t.line != "" {
m["line"] = t.line m["line"] = t.line
} }

View file

@ -40,7 +40,7 @@ func TestTaskCancellation(t *testing.T) {
err := subTask.Context().Err() err := subTask.Context().Err()
ExpectError(t, context.Canceled, err) ExpectError(t, context.Canceled, err)
cause := context.Cause(subTask.Context()) cause := context.Cause(subTask.Context())
ExpectError(t, ErrTaskCancelled, cause) ExpectError(t, ErrTaskCanceled, cause)
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected") t.Fatal("subTask context was not canceled as expected")
} }
@ -74,7 +74,7 @@ func TestOnComplete(t *testing.T) {
task := GlobalTask("test") task := GlobalTask("test")
var value atomic.Int32 var value atomic.Int32
task.OnComplete("set value", func() { task.OnFinished("set value", func() {
value.Store(1234) value.Store(1234)
}) })
task.Finish("done") task.Finish("done")
@ -90,10 +90,10 @@ func TestGlobalContextWait(t *testing.T) {
subTask1 := rootTask.Subtask("subtask1") subTask1 := rootTask.Subtask("subtask1")
subTask2 := rootTask.Subtask("subtask2") subTask2 := rootTask.Subtask("subtask2")
subTask1.OnComplete("set finished", func() { subTask1.OnFinished("set finished", func() {
finished1 = true finished1 = true
}) })
subTask2.OnComplete("set finished", func() { subTask2.OnFinished("set finished", func() {
finished2 = true finished2 = true
}) })
@ -117,8 +117,8 @@ func TestGlobalContextWait(t *testing.T) {
ExpectTrue(t, finished1) ExpectTrue(t, finished1)
ExpectTrue(t, finished2) ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err()) ExpectError(t, context.Canceled, rootTask.Context().Err())
ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context())) ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context())) ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
} }
func TestTimeoutOnGlobalContextWait(t *testing.T) { func TestTimeoutOnGlobalContextWait(t *testing.T) {

View file

@ -160,7 +160,7 @@ func (m Map[KT, VT]) Has(k KT) bool {
// Returns: // Returns:
// //
// error: if the unmarshaling fails // error: if the unmarshaling fails
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error {
if m.Size() != 0 { if m.Size() != 0 {
return E.FailedWhy("unmarshal from yaml", "map is not empty") return E.FailedWhy("unmarshal from yaml", "map is not empty")
} }

View file

@ -152,7 +152,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {
return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src})
} }
func LoadJSON[T any](path string, pointer *T) E.NestedError { func LoadJSON[T any](path string, pointer *T) E.Error {
data, err := E.Check(os.ReadFile(path)) data, err := E.Check(os.ReadFile(path))
if err.HasError() { if err.HasError() {
return err return err
@ -160,7 +160,7 @@ func LoadJSON[T any](path string, pointer *T) E.NestedError {
return E.From(json.Unmarshal(data, pointer)) return E.From(json.Unmarshal(data, pointer))
} }
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError { func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error {
data, err := E.Check(json.Marshal(pointer)) data, err := E.Check(json.Marshal(pointer))
if err.HasError() { if err.HasError() {
return err return err

View file

@ -19,11 +19,11 @@ import (
type ( type (
SerializedObject = map[string]any SerializedObject = map[string]any
Converter interface { Converter interface {
ConvertFrom(value any) E.NestedError ConvertFrom(value any) E.Error
} }
) )
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error {
var i any var i any
err := yaml.Unmarshal(data, &i) err := yaml.Unmarshal(data, &i)
@ -66,7 +66,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
// Returns: // Returns:
// - result: The resulting map[string]any representation of the data. // - result: The resulting map[string]any representation of the data.
// - error: An error if the data type is unsupported or if there is an error during conversion. // - error: An error if the data type is unsupported or if there is an error during conversion.
func Serialize(data any) (SerializedObject, E.NestedError) { func Serialize(data any) (SerializedObject, E.Error) {
result := make(map[string]any) result := make(map[string]any)
// Use reflection to inspect the data type // Use reflection to inspect the data type
@ -137,7 +137,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map. // If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
// //
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.NestedError { func Deserialize(src SerializedObject, dst any) E.Error {
if src == nil { if src == nil {
return E.Invalid("src", "nil") return E.Invalid("src", "nil")
} }
@ -210,7 +210,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
// //
// Returns: // Returns:
// - error: the error occurred during conversion, or nil if no error occurred. // - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError { func Convert(src reflect.Value, dst reflect.Value) E.Error {
srcT := src.Type() srcT := src.Type()
dstT := dst.Type() dstT := dst.Type()
@ -277,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
return converter.ConvertFrom(src.Interface()) return converter.ConvertFrom(src.Interface())
} }
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) { func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
convertible = true convertible = true
if dst.Kind() == reflect.Ptr { if dst.Kind() == reflect.Ptr {
if dst.IsNil() { if dst.IsNil() {
@ -379,7 +379,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
return true, Convert(reflect.ValueOf(tmp), dst) return true, Convert(reflect.ValueOf(tmp), dst)
} }
func DeserializeJSON(j map[string]string, target any) E.NestedError { func DeserializeJSON(j map[string]string, target any) E.Error {
data, err := E.Check(json.Marshal(j)) data, err := E.Check(json.Marshal(j))
if err != nil { if err != nil {
return err return err

View file

@ -113,7 +113,7 @@ type testType struct {
bar string bar string
} }
func (c *testType) ConvertFrom(v any) E.NestedError { func (c *testType) ConvertFrom(v any) E.Error {
switch v := v.(type) { switch v := v.(type) {
case string: case string:
c.bar = v c.bar = v

View file

@ -98,7 +98,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
return got.(T) return got.(T)
} }
func Must[T any](v T, err E.NestedError) T { func Must[T any](v T, err E.Error) T {
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -21,7 +21,7 @@ type DirWatcher struct {
mu sync.Mutex mu sync.Mutex
eventCh chan Event eventCh chan Event
errCh chan E.NestedError errCh chan E.Error
ctx context.Context ctx context.Context
} }
@ -48,14 +48,14 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher {
w: w, w: w,
fwMap: F.NewMapOf[string, *fileWatcher](), fwMap: F.NewMapOf[string, *fileWatcher](),
eventCh: make(chan Event), eventCh: make(chan Event),
errCh: make(chan E.NestedError), errCh: make(chan E.Error),
ctx: ctx, ctx: ctx,
} }
go helper.start() go helper.start()
return helper return helper
} }
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) { func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.Error) {
return h.eventCh, h.errCh return h.eventCh, h.errCh
} }
@ -71,7 +71,7 @@ func (h *DirWatcher) Add(relPath string) Watcher {
s = &fileWatcher{ s = &fileWatcher{
relPath: relPath, relPath: relPath,
eventCh: make(chan Event), eventCh: make(chan Event),
errCh: make(chan E.NestedError), errCh: make(chan E.Error),
} }
go func() { go func() {
defer func() { defer func() {

View file

@ -36,6 +36,14 @@ var (
NewDockerFilter = filters.NewArgs NewDockerFilter = filters.NewArgs
optionsDefault = DockerListOptions{Filters: NewDockerFilter(
DockerFilterContainer,
DockerFilterStart,
// DockerFilterStop,
DockerFilterDie,
DockerFilterDestroy,
)}
dockerWatcherRetryInterval = 3 * time.Second dockerWatcherRetryInterval = 3 * time.Second
) )
@ -61,13 +69,13 @@ func NewDockerWatcherWithClient(client D.Client) DockerWatcher {
WithField("host", client.DaemonHost()))} WithField("host", client.DaemonHost()))}
} }
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
return w.EventsWithOptions(ctx, optionsWatchAll) return w.EventsWithOptions(ctx, optionsDefault)
} }
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) { func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.Error) {
eventCh := make(chan Event) eventCh := make(chan Event)
errCh := make(chan E.NestedError) errCh := make(chan E.Error)
go func() { go func() {
defer close(eventCh) defer close(eventCh)
@ -80,7 +88,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
}() }()
if !w.client.Connected() { if !w.client.Connected() {
var err E.NestedError var err E.Error
attempts := 0 attempts := 0
for { for {
w.client, err = D.ConnectClient(w.host) w.client, err = D.ConnectClient(w.host)
@ -141,11 +149,3 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
return eventCh, errCh return eventCh, errCh
} }
var optionsWatchAll = DockerListOptions{Filters: NewDockerFilter(
DockerFilterContainer,
DockerFilterStart,
// DockerFilterStop,
DockerFilterDie,
DockerFilterDestroy,
)}

View file

@ -3,20 +3,22 @@ package events
import ( import (
"time" "time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
) )
type ( type (
EventQueue struct { EventQueue struct {
task task.Task task task.Task
queue []Event queue []Event
ticker *time.Ticker ticker *time.Ticker
onFlush OnFlushFunc flushInterval time.Duration
onError OnErrorFunc onFlush OnFlushFunc
onError OnErrorFunc
} }
OnFlushFunc = func(flushTask task.Task, events []Event) OnFlushFunc = func(flushTask task.Task, events []Event)
OnErrorFunc = func(err E.NestedError) OnErrorFunc = func(err E.Error)
) )
const eventQueueCapacity = 10 const eventQueueCapacity = 10
@ -35,40 +37,45 @@ const eventQueueCapacity = 10
// flushTask.Finish must be called after the flush is done, // flushTask.Finish must be called after the flush is done,
// but the onFlush function can return earlier (e.g. run in another goroutine). // but the onFlush function can return earlier (e.g. run in another goroutine).
// //
// If task is cancelled before the flushInterval is reached, the events in queue will be discarded. // If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
return &EventQueue{ return &EventQueue{
task: parent.Subtask("event queue"), task: parent.Subtask("event queue"),
queue: make([]Event, 0, eventQueueCapacity), queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval), ticker: time.NewTicker(flushInterval),
onFlush: onFlush, flushInterval: flushInterval,
onError: onError, onFlush: onFlush,
onError: onError,
} }
} }
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) { func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
go func() { go func() {
defer e.ticker.Stop() defer e.ticker.Stop()
for { for {
select { select {
case <-e.task.Context().Done(): case <-e.task.Context().Done():
e.task.Finish(e.task.FinishCause().Error())
return return
case <-e.ticker.C: case <-e.ticker.C:
if len(e.queue) > 0 { if len(e.queue) > 0 {
flushTask := e.task.Subtask("flush events") flushTask := e.task.Subtask("flush events")
queue := e.queue queue := e.queue
e.queue = make([]Event, 0, eventQueueCapacity) e.queue = make([]Event, 0, eventQueueCapacity)
go func() { if !common.IsDebug {
defer func() { go func() {
if err := recover(); err != nil { defer func() {
e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name())) if err := recover(); err != nil {
} e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name()))
}
}()
e.onFlush(flushTask, queue)
}() }()
e.onFlush(flushTask, queue) } else {
}() go e.onFlush(flushTask, queue)
}
flushTask.Wait() flushTask.Wait()
} }
e.ticker.Reset(e.flushInterval)
case event, ok := <-eventCh: case event, ok := <-eventCh:
e.queue = append(e.queue, event) e.queue = append(e.queue, event)
if !ok { if !ok {

View file

@ -9,9 +9,9 @@ import (
type fileWatcher struct { type fileWatcher struct {
relPath string relPath string
eventCh chan Event eventCh chan Event
errCh chan E.NestedError errCh chan E.Error
} }
func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
return fw.eventCh, fw.errCh return fw.eventCh, fw.errCh
} }

View file

@ -0,0 +1,28 @@
package health
import (
"encoding/json"
"fmt"
"time"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/task"
)
type (
HealthMonitor interface {
task.TaskStarter
task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status
Uptime() time.Duration
Name() string
}
HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
}
)

View file

@ -2,9 +2,7 @@ package health
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt"
"time" "time"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -15,21 +13,6 @@ import (
) )
type ( type (
HealthMonitor interface {
task.TaskStarter
task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status
Uptime() time.Duration
Name() string
}
HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
}
HealthCheckFunc func() (healthy bool, detail string, err error) HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct { monitor struct {
service string service string
@ -71,7 +54,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
} }
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError { func (mon *monitor) Start(routeSubtask task.Task) E.Error {
mon.service = routeSubtask.Parent().Name() mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask mon.task = routeSubtask
@ -84,7 +67,6 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
if mon.status.Load() != StatusError { if mon.status.Load() != StatusError {
mon.status.Store(StatusUnknown) mon.status.Store(StatusUnknown)
} }
mon.task.Finish(mon.task.FinishCause().Error())
}() }()
if err := mon.checkUpdateHealth(); err != nil { if err := mon.checkUpdateHealth(); err != nil {
@ -115,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
} }
// Finish implements task.TaskFinisher. // Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason string) { func (mon *monitor) Finish(reason any) {
mon.task.Finish(reason) mon.task.Finish(reason)
} }
@ -169,10 +151,10 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
}).MarshalJSON() }).MarshalJSON()
} }
func (mon *monitor) checkUpdateHealth() E.NestedError { func (mon *monitor) checkUpdateHealth() E.Error {
healthy, detail, err := mon.checkHealth() healthy, detail, err := mon.checkHealth()
if err != nil { if err != nil {
defer mon.task.Finish(err.Error()) defer mon.task.Finish(err)
mon.status.Store(StatusError) mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
return E.Failure("check health").With(err) return E.Failure("check health").With(err)

View file

@ -10,5 +10,5 @@ import (
type Event = events.Event type Event = events.Event
type Watcher interface { type Watcher interface {
Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) Events(ctx context.Context) (<-chan Event, <-chan E.Error)
} }