diff --git a/.gitignore b/.gitignore index 2f9becb..f89bbdf 100755 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ compose.yml *.compose.yml +config +certs config*/ certs*/ bin/ diff --git a/go.mod b/go.mod index 6425cee..1868514 100644 --- a/go.mod +++ b/go.mod @@ -41,12 +41,12 @@ require ( github.com/ovh/go-ovh v1.6.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect - go.opentelemetry.io/otel v1.30.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect + go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect - go.opentelemetry.io/otel/metric v1.30.0 // indirect + go.opentelemetry.io/otel/metric v1.31.0 // indirect go.opentelemetry.io/otel/sdk v1.30.0 // indirect - go.opentelemetry.io/otel/trace v1.30.0 // indirect + go.opentelemetry.io/otel/trace v1.31.0 // indirect golang.org/x/crypto v0.28.0 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect diff --git a/go.sum b/go.sum index 3c4fe09..958c6d3 100644 --- a/go.sum +++ b/go.sum @@ -96,20 +96,20 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI= -go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts= -go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM= +go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= +go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8= -go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w= -go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ= +go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= +go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE= go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg= -go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc= -go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o= +go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= +go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/internal/api/v1/checkhealth.go b/internal/api/v1/checkhealth.go index 453ee02..be0de54 100644 --- a/internal/api/v1/checkhealth.go +++ b/internal/api/v1/checkhealth.go @@ -15,10 +15,15 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { return } - status, ok := health.Inspect(target) + result, ok := health.Inspect(target) if !ok { HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound) return } - WriteBody(w, []byte(status.String())) + json, err := result.MarshalJSON() + if err != nil { + HandleErr(w, r, err) + return + } + RespondJSON(w, r, json) } diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 66a4076..86208a8 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -8,6 +8,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/utils" ) @@ -45,16 +46,7 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { } func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - routes := cfg.RoutesByAlias() - typeFilter := r.FormValue("type") - if typeFilter != "" { - for k, v := range routes { - if v["type"] != typeFilter { - delete(routes, k) - } - } - } - + routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type"))) U.RespondJSON(w, r, routes) } diff --git a/internal/common/task.go b/internal/common/task.go index 6e0e61b..dd55a6c 100644 --- a/internal/common/task.go +++ b/internal/common/task.go @@ -212,9 +212,9 @@ func GlobalContextWait(timeout time.Duration) { case <-done: return case <-after: - logrus.Println("Timeout waiting for these tasks to finish:") + logrus.Warnln("Timeout waiting for these tasks to finish:") tasksMap.Range(func(t *task, _ struct{}) bool { - logrus.Println(t.tree()) + logrus.Warnln(t.tree()) return true }) return diff --git a/internal/config/config.go b/internal/config/config.go index 444e0f8..bc6fcfc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -87,9 +87,8 @@ func (cfg *Config) StartProxyProviders() { func (cfg *Config) WatchChanges() { task := common.NewTask("Config watcher") - defer task.Finished() - go func() { + defer task.Finished() for { select { case <-task.Context().Done(): diff --git a/internal/config/query.go b/internal/config/query.go index 9d04469..cd77ef9 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -42,25 +42,18 @@ func (cfg *Config) HomepageConfig() homepage.Config { } hpCfg := homepage.NewHomePageConfig() - cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { - if !r.Started() { - return - } - - entry := r.Entry - if entry.Homepage == nil { - entry.Homepage = &homepage.Item{ - Show: r.Entry.IsExplicit || !p.IsExplicitOnly(), - } - } - + R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { + entry := r.Raw item := entry.Homepage + if item == nil { + item = new(homepage.Item) + } - if !item.Show && !item.IsEmpty() { + if !item.Show && item.IsEmpty() { item.Show = true } - if !item.Show || r.Type != R.RouteTypeReverseProxy { + if !item.Show { return } @@ -73,12 +66,17 @@ func (cfg *Config) HomepageConfig() homepage.Config { ) } - if p.GetType() == PR.ProviderTypeDocker { + if r.IsDocker() { if item.Category == "" { item.Category = "Docker" } item.SourceType = string(PR.ProviderTypeDocker) - } else if p.GetType() == PR.ProviderTypeFile { + } else if r.UseLoadBalance() { + if item.Category == "" { + item.Category = "Load-balanced" + } + item.SourceType = "loadbalancer" + } else { if item.Category == "" { item.Category = "Others" } @@ -97,28 +95,20 @@ func (cfg *Config) HomepageConfig() homepage.Config { return hpCfg } -func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]U.SerializedObject { - routes := make(map[string]U.SerializedObject) - if len(typeFilter) == 0 { +func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any { + routes := make(map[string]any) + if len(typeFilter) == 0 || typeFilter[0] == "" { typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream} } for _, t := range typeFilter { switch t { case R.RouteTypeReverseProxy: R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { - obj, err := U.Serialize(r) - if err != nil { - panic(err) // should not happen - } - routes[alias] = obj + routes[alias] = r }) case R.RouteTypeStream: R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) { - obj, err := U.Serialize(r) - if err != nil { - panic(err) // should not happen - } - routes[alias] = obj + routes[alias] = r }) } } diff --git a/internal/docker/client.go b/internal/docker/client.go index a07a27e..c8d9941 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -43,7 +43,9 @@ func init() { select { case <-task.Context().Done(): clientMap.RangeAllParallel(func(_ string, c Client) { - c.Client.Close() + if c.Connected() { + c.Client.Close() + } }) clientMap.Clear() return diff --git a/internal/docker/container.go b/internal/docker/container.go index 5f7e5f6..115e520 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -41,6 +41,8 @@ type ( } ) +var DummyContainer = new(Container) + func FromDocker(c *types.Container, dockerHost string) (res *Container) { isExplicit := c.Labels[LabelAliases] != "" helper := containerHelper{c} diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 466768a..fa0c4e4 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -2,7 +2,6 @@ package idlewatcher import ( "context" - "encoding/json" "net/http" "strconv" "time" @@ -73,15 +72,15 @@ func (w *Waker) Uptime() time.Duration { } func (w *Waker) MarshalJSON() ([]byte, error) { - return json.Marshal(map[string]any{ - "name": w.Name(), - "url": w.URL, - "status": w.Status(), - "config": health.HealthCheckConfig{ + return (&health.JSONRepresentation{ + Name: w.Name(), + Status: w.Status(), + Config: &health.HealthCheckConfig{ Interval: w.IdleTimeout, Timeout: w.WakeTimeout, }, - }) + URL: w.URL, + }).MarshalJSON() } /* End of HealthMonitor interface */ @@ -89,6 +88,10 @@ func (w *Waker) MarshalJSON() ([]byte, error) { func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { w.resetIdleTimer() + if r.Body != nil { + defer r.Body.Close() + } + // pass through if container is ready if w.ready.Load() { return true @@ -115,6 +118,16 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) return } + select { + case <-w.task.Context().Done(): + http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) + return + case <-ctx.Done(): + http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) + return + default: + } + // wake the container and reset idle timer // also wait for another wake request w.wakeCh <- struct{}{} @@ -169,3 +182,8 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) time.Sleep(100 * time.Millisecond) } } + +// static HealthMonitor interface check +func (w *Waker) _() health.HealthMonitor { + return w +} diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index 98c0ec8..b3f3776 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -8,10 +8,12 @@ import ( "github.com/docker/docker/api/types/container" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" + U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" W "github.com/yusing/go-proxy/internal/watcher" ) @@ -29,9 +31,10 @@ type ( wakeDone chan E.NestedError ticker *time.Ticker - ctx context.Context - cancel context.CancelFunc - refCount *sync.WaitGroup + task common.Task + cancel context.CancelFunc + + refCount *U.RefCount l logrus.FieldLogger } @@ -42,17 +45,11 @@ type ( ) var ( - mainLoopCtx context.Context - mainLoopCancel context.CancelFunc - mainLoopWg sync.WaitGroup - watcherMap = F.NewMapOf[string, *Watcher]() watcherMapMu sync.Mutex portHistoryMap = F.NewMapOf[PT.Alias, string]() - newWatcherCh = make(chan *Watcher) - logger = logrus.WithField("module", "idle_watcher") ) @@ -73,7 +70,7 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) { } if w, ok := watcherMap.Load(key); ok { - w.refCount.Add(1) + w.refCount.Add() w.ReverseProxyEntry = entry return w, nil } @@ -86,83 +83,51 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) { w := &Watcher{ ReverseProxyEntry: entry, client: client, - refCount: &sync.WaitGroup{}, + refCount: U.NewRefCounter(), wakeCh: make(chan struct{}, 1), wakeDone: make(chan E.NestedError), ticker: time.NewTicker(entry.IdleTimeout), l: logger.WithField("container", entry.ContainerName), } - w.refCount.Add(1) + w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias) w.stopByMethod = w.getStopCallback() watcherMap.Store(key, w) - go func() { - newWatcherCh <- w - }() + go w.watchUntilCancel() return w, nil } func (w *Watcher) Unregister() { - w.refCount.Add(-1) -} - -func Start() { - logger.Debug("started") - defer logger.Debug("stopped") - - mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background()) - - for { - select { - case <-mainLoopCtx.Done(): - return - case w := <-newWatcherCh: - w.l.Debug("registered") - mainLoopWg.Add(1) - go func() { - w.watchUntilCancel() - w.refCount.Wait() // wait for 0 ref count - - watcherMap.Delete(w.ContainerID) - w.l.Debug("unregistered") - mainLoopWg.Done() - }() - } - } -} - -func Stop() { - mainLoopCancel() - mainLoopWg.Wait() + w.refCount.Sub() } func (w *Watcher) containerStop() error { - return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{ + return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{ Signal: string(w.StopSignal), Timeout: &w.StopTimeout, }) } func (w *Watcher) containerPause() error { - return w.client.ContainerPause(w.ctx, w.ContainerID) + return w.client.ContainerPause(w.task.Context(), w.ContainerID) } func (w *Watcher) containerKill() error { - return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal)) + return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal)) } func (w *Watcher) containerUnpause() error { - return w.client.ContainerUnpause(w.ctx, w.ContainerID) + return w.client.ContainerUnpause(w.task.Context(), w.ContainerID) } func (w *Watcher) containerStart() error { - return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{}) + return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{}) } func (w *Watcher) containerStatus() (string, E.NestedError) { - json, err := w.client.ContainerInspect(w.ctx, w.ContainerID) + json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID) if err != nil { return "", E.FailWith("inspect container", err) } @@ -221,12 +186,8 @@ func (w *Watcher) resetIdleTimer() { } func (w *Watcher) watchUntilCancel() { - defer close(w.wakeCh) - - w.ctx, w.cancel = context.WithCancel(mainLoopCtx) - dockerWatcher := W.NewDockerWatcherWithClient(w.client) - dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{ + dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, W.DockerrFilterContainer(w.ContainerID), @@ -238,13 +199,23 @@ func (w *Watcher) watchUntilCancel() { W.DockerFilterUnpause, ), }) - defer w.ticker.Stop() - defer w.client.Close() + + defer func() { + w.ticker.Stop() + w.client.Close() + close(w.wakeDone) + close(w.wakeCh) + watcherMap.Delete(w.ContainerID) + w.task.Finished() + }() for { select { - case <-w.ctx.Done(): - w.l.Debug("stopped") + case <-w.task.Context().Done(): + w.l.Debug("stopped by context done") + return + case <-w.refCount.Zero(): + w.l.Debug("stopped by zero ref count") return case err := <-dockerEventErrCh: if err != nil && err.IsNot(context.Canceled) { diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index fcafe77..ae277ee 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -7,6 +7,17 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) +func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) { + client, err := ConnectClient(dockerHost) + defer client.Close() + + if err.HasError() { + return nil, E.FailWith("connect to docker", err) + } + + return client.Inspect(containerID) +} + func (c Client) Inspect(containerID string) (*Container, E.NestedError) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() diff --git a/internal/error/builder_test.go b/internal/error/builder_test.go index d4bc6ec..0f3c613 100644 --- a/internal/error/builder_test.go +++ b/internal/error/builder_test.go @@ -46,7 +46,18 @@ func TestBuilderNested(t *testing.T) { - invalid Inner: "2" - Action 2 failed: - invalid Inner: "3"`) - if got != expected1 && got != expected2 { - t.Errorf("expected \n%s, got \n%s", expected1, got) - } + ExpectEqualAny(t, got, []string{expected1, expected2}) +} + +func TestBuilderTo(t *testing.T) { + eb := NewBuilder("error occurred") + eb.Addf("abcd") + + var err NestedError + eb.To(&err) + got := err.String() + expected := (`error occurred: + - abcd`) + + ExpectEqual(t, got, expected) } diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 3a86bfd..0d3bee1 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -6,8 +6,8 @@ import ( "time" "github.com/go-acme/lego/v4/log" - E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/watcher/health" ) // TODO: stats of each server. @@ -41,13 +41,14 @@ type ( const maxWeight weightType = 100 func New(cfg *Config) *LoadBalancer { - lb := &LoadBalancer{Config: cfg, pool: servers{}} - mode := cfg.Mode - if !cfg.Mode.ValidateUpdate() { - logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode) - } - switch mode { - case RoundRobin: + lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)} + lb.UpdateConfigIfNeeded(cfg) + return lb +} + +func (lb *LoadBalancer) updateImpl() { + switch lb.Mode { + case Unset, RoundRobin: lb.impl = lb.newRoundRobin() case LeastConn: lb.impl = lb.newLeastConn() @@ -56,7 +57,34 @@ func New(cfg *Config) *LoadBalancer { default: // should happen in test only lb.impl = lb.newRoundRobin() } - return lb + for _, srv := range lb.pool { + lb.impl.OnAddServer(srv) + } +} + +func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { + if cfg != nil { + lb.poolMu.Lock() + defer lb.poolMu.Unlock() + + lb.Link = cfg.Link + + if lb.Mode == Unset && cfg.Mode != Unset { + lb.Mode = cfg.Mode + if !lb.Mode.ValidateUpdate() { + logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode) + } + lb.updateImpl() + } + + if len(lb.Options) == 0 && len(cfg.Options) > 0 { + lb.Options = cfg.Options + } + } + + if lb.impl == nil { + lb.updateImpl() + } } func (lb *LoadBalancer) AddServer(srv *Server) { @@ -66,6 +94,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.pool = append(lb.pool, srv) lb.sumWeight += srv.Weight + lb.Rebalance() lb.impl.OnAddServer(srv) logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool)) } @@ -74,6 +103,8 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() + lb.sumWeight -= srv.Weight + lb.Rebalance() lb.impl.OnRemoveServer(srv) for i, s := range lb.pool { @@ -87,7 +118,6 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { return } - lb.Rebalance() logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool)) } @@ -152,20 +182,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func (lb *LoadBalancer) Start() { - if lb.sumWeight != 0 && lb.sumWeight != maxWeight { - msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight) - for _, s := range lb.pool { - msg.Addf("%s: %d", s.Name, s.Weight) - } - lb.Rebalance() - inner := E.NewBuilder("after rebalancing") - for _, s := range lb.pool { - inner.Addf("%s: %d", s.Name, s.Weight) - } - msg.Addf("%s", inner) - logger.Warn(msg) - } - if lb.sumWeight != 0 { log.Warnf("weighted mode not supported yet") } @@ -186,6 +202,45 @@ func (lb *LoadBalancer) Uptime() time.Duration { return time.Since(lb.startTime) } +// MarshalJSON implements health.HealthMonitor. +func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { + extra := make(map[string]any) + for _, v := range lb.pool { + extra[v.Name] = v.healthMon + } + return (&health.JSONRepresentation{ + Name: lb.Name(), + Status: lb.Status(), + Started: lb.startTime, + Uptime: lb.Uptime(), + Extra: map[string]any{ + "config": lb.Config, + "pool": extra, + }, + }).MarshalJSON() +} + +// Name implements health.HealthMonitor. +func (lb *LoadBalancer) Name() string { + return lb.Link +} + +// Status implements health.HealthMonitor. +func (lb *LoadBalancer) Status() health.Status { + if len(lb.pool) == 0 { + return health.StatusUnknown + } + if len(lb.availServers()) == 0 { + return health.StatusUnhealthy + } + return health.StatusHealthy +} + +// String implements health.HealthMonitor. +func (lb *LoadBalancer) String() string { + return lb.Name() +} + func (lb *LoadBalancer) availServers() servers { lb.poolMu.Lock() defer lb.poolMu.Unlock() @@ -199,3 +254,8 @@ func (lb *LoadBalancer) availServers() servers { } return avail } + +// static HealthMonitor interface check +func (lb *LoadBalancer) _() health.HealthMonitor { + return lb +} diff --git a/internal/net/http/loadbalancer/mode.go b/internal/net/http/loadbalancer/mode.go index 9d6f91d..919d311 100644 --- a/internal/net/http/loadbalancer/mode.go +++ b/internal/net/http/loadbalancer/mode.go @@ -7,6 +7,7 @@ import ( type Mode string const ( + Unset Mode = "" RoundRobin Mode = "roundrobin" LeastConn Mode = "leastconn" IPHash Mode = "iphash" @@ -14,7 +15,9 @@ const ( func (mode *Mode) ValidateUpdate() bool { switch U.ToLowerNoSnake(string(*mode)) { - case "", string(RoundRobin): + case "": + return true + case string(RoundRobin): *mode = RoundRobin return true case string(LeastConn): diff --git a/internal/proxy/entry.go b/internal/proxy/entry.go index 97bbb12..08979e0 100644 --- a/internal/proxy/entry.go +++ b/internal/proxy/entry.go @@ -16,32 +16,36 @@ import ( type ( ReverseProxyEntry struct { // real model after validation - Alias T.Alias `json:"alias"` - Scheme T.Scheme `json:"scheme"` - URL net.URL `json:"url"` + Raw *types.RawEntry `json:"raw"` + + Alias T.Alias `json:"alias,omitempty"` + Scheme T.Scheme `json:"scheme,omitempty"` + URL net.URL `json:"url,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"` - PathPatterns T.PathPatterns `json:"path_patterns"` - HealthCheck *health.HealthCheckConfig `json:"healthcheck"` + PathPatterns T.PathPatterns `json:"path_patterns,omitempty"` + HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"` Middlewares D.NestedLabelMap `json:"middlewares,omitempty"` /* Docker only */ - IdleTimeout time.Duration `json:"idle_timeout"` - WakeTimeout time.Duration `json:"wake_timeout"` - StopMethod T.StopMethod `json:"stop_method"` - StopTimeout int `json:"stop_timeout"` + IdleTimeout time.Duration `json:"idle_timeout,omitempty"` + WakeTimeout time.Duration `json:"wake_timeout,omitempty"` + StopMethod T.StopMethod `json:"stop_method,omitempty"` + StopTimeout int `json:"stop_timeout,omitempty"` StopSignal T.Signal `json:"stop_signal,omitempty"` - DockerHost string `json:"docker_host"` - ContainerName string `json:"container_name"` - ContainerID string `json:"container_id"` - ContainerRunning bool `json:"container_running"` + DockerHost string `json:"docker_host,omitempty"` + ContainerName string `json:"container_name,omitempty"` + ContainerID string `json:"container_id,omitempty"` + ContainerRunning bool `json:"container_running,omitempty"` } StreamEntry struct { - Alias T.Alias `json:"alias"` - Scheme T.StreamScheme `json:"scheme"` - Host T.Host `json:"host"` - Port T.StreamPort `json:"port"` - Healthcheck *health.HealthCheckConfig `json:"healthcheck"` + Raw *types.RawEntry `json:"raw"` + + Alias T.Alias `json:"alias,omitempty"` + Scheme T.StreamScheme `json:"scheme,omitempty"` + Host T.Host `json:"host,omitempty"` + Port T.StreamPort `json:"port,omitempty"` + Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"` } ) @@ -88,6 +92,10 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) { func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry { var stopTimeOut time.Duration + cont := m.Container + if cont == nil { + cont = D.DummyContainer + } host, err := T.ValidateHost(m.Host) b.Add(err) @@ -101,21 +109,21 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) b.Add(err) - idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout) + idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout) b.Add(err) - wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout) + wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout) b.Add(err) - stopMethod, err := T.ValidateStopMethod(m.StopMethod) + stopMethod, err := T.ValidateStopMethod(cont.StopMethod) b.Add(err) if stopMethod == T.StopMethodStop { - stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout) + stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout) b.Add(err) } - stopSignal, err := T.ValidateSignal(m.StopSignal) + stopSignal, err := T.ValidateSignal(cont.StopSignal) b.Add(err) if err != nil { @@ -123,6 +131,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn } return &ReverseProxyEntry{ + Raw: m, Alias: T.NewAlias(m.Alias), Scheme: s, URL: net.NewURL(url), @@ -136,10 +145,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn StopMethod: stopMethod, StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument StopSignal: stopSignal, - DockerHost: m.DockerHost, - ContainerName: m.ContainerName, - ContainerID: m.ContainerID, - ContainerRunning: m.Running, + DockerHost: cont.DockerHost, + ContainerName: cont.ContainerName, + ContainerID: cont.ContainerID, + ContainerRunning: cont.Running, } } @@ -158,6 +167,7 @@ func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry { } return &StreamEntry{ + Raw: m, Alias: T.NewAlias(m.Alias), Scheme: *scheme, Host: host, diff --git a/internal/proxy/fields/stream_scheme.go b/internal/proxy/fields/stream_scheme.go index 6768460..17835db 100644 --- a/internal/proxy/fields/stream_scheme.go +++ b/internal/proxy/fields/stream_scheme.go @@ -32,7 +32,7 @@ func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { } func (s StreamScheme) String() string { - return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme) + return fmt.Sprintf("%s -> %s", s.ListeningScheme, s.ProxyScheme) } // IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal. diff --git a/internal/proxy/provider/docker.go b/internal/proxy/provider/docker.go index f493c06..50fae90 100755 --- a/internal/proxy/provider/docker.go +++ b/internal/proxy/provider/docker.go @@ -72,7 +72,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { } entries.RangeAll(func(_ string, e *types.RawEntry) { - e.DockerHost = p.dockerHost + e.Container.DockerHost = p.dockerHost }) routes, err = R.FromEntries(entries) @@ -88,7 +88,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool { strings.HasSuffix(container.ContainerName, "-old") } -func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { +func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) { switch event.Action { case events.ActionContainerStart, events.ActionContainerStop: break @@ -98,75 +98,66 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul b := E.NewBuilder("event %s error", event) defer b.To(&res.err) - routes.RangeAll(func(k string, v *R.Route) { - if v.Entry.ContainerID == event.ActorID || - v.Entry.ContainerName == event.ActorName { + matches := R.NewRoutes() + oldRoutes.RangeAllParallel(func(k string, v *R.Route) { + if v.Entry.Container.ContainerID == event.ActorID || + v.Entry.Container.ContainerName == event.ActorName { + matches.Store(k, v) + } + }) + + var newRoutes R.Routes + var err E.NestedError + + if matches.Size() == 0 { // id & container name changed + matches = oldRoutes + newRoutes, err = p.LoadRoutesImpl() + b.Add(err) + } else { + cont, err := D.Inspect(p.dockerHost, event.ActorID) + if err != nil { + b.Add(E.FailWith("inspect container", err)) + return + } + + if p.shouldIgnore(cont) { + // stop all old routes + matches.RangeAllParallel(func(_ string, v *R.Route) { + b.Add(v.Stop()) + }) + return + } + + entries, err := p.entriesFromContainerLabels(cont) + b.Add(err) + newRoutes, err = R.FromEntries(entries) + b.Add(err) + } + + matches.RangeAll(func(k string, v *R.Route) { + if !newRoutes.Has(k) && !oldRoutes.Has(k) { b.Add(v.Stop()) - routes.Delete(k) + matches.Delete(k) res.nRemoved++ } }) - if res.nRemoved == 0 { // id & container name changed - // load all routes (rescan) - routesNew, err := p.LoadRoutesImpl() - routesOld := routes - if routesNew.Size() == 0 { - b.Add(E.FailWith("rescan routes", err)) - return - } - routesNew.Range(func(k string, v *R.Route) bool { - if !routesOld.Has(k) { - routesOld.Store(k, v) - b.Add(v.Start()) - res.nAdded++ - return false - } - return true - }) - routesOld.Range(func(k string, v *R.Route) bool { - if !routesNew.Has(k) { - b.Add(v.Stop()) - routesOld.Delete(k) - res.nRemoved++ - return false - } - return true - }) - return - } - - client, err := D.ConnectClient(p.dockerHost) - if err != nil { - b.Add(E.FailWith("connect to docker", err)) - return - } - defer client.Close() - cont, err := client.Inspect(event.ActorID) - if err != nil { - b.Add(E.FailWith("inspect container", err)) - return - } - - if p.shouldIgnore(cont) { - return - } - - entries, err := p.entriesFromContainerLabels(cont) - b.Add(err) - - entries.RangeAll(func(alias string, entry *types.RawEntry) { - if routes.Has(alias) { - b.Add(E.Duplicated("alias", alias)) - } else { - if route, err := R.NewRoute(entry); err != nil { + newRoutes.RangeAll(func(alias string, newRoute *R.Route) { + oldRoute, exists := oldRoutes.Load(alias) + if exists { + if err := oldRoute.Stop(); err != nil { b.Add(err) - } else { - routes.Store(alias, route) - b.Add(route.Start()) - res.nAdded++ } } + oldRoutes.Store(alias, newRoute) + if err := newRoute.Start(); err != nil { + b.Add(err) + } + if exists { + res.nReloaded++ + } else { + res.nAdded++ + } }) return diff --git a/internal/proxy/provider/docker_test.go b/internal/proxy/provider/docker_test.go index 90a424f..89959f0 100644 --- a/internal/proxy/provider/docker_test.go +++ b/internal/proxy/provider/docker_test.go @@ -88,20 +88,20 @@ func TestApplyLabelWildcard(t *testing.T) { ExpectDeepEqual(t, a.Middlewares, middlewaresExpect) ExpectEqual(t, len(b.Middlewares), 0) - ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault) - ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault) + ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault) + ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault) - ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault) - ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault) + ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault) + ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault) - ExpectEqual(t, a.StopMethod, common.StopMethodDefault) - ExpectEqual(t, b.StopMethod, common.StopMethodDefault) + ExpectEqual(t, a.Container.StopMethod, common.StopMethodDefault) + ExpectEqual(t, b.Container.StopMethod, common.StopMethodDefault) - ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault) - ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault) + ExpectEqual(t, a.Container.WakeTimeout, common.WakeTimeoutDefault) + ExpectEqual(t, b.Container.WakeTimeout, common.WakeTimeoutDefault) - ExpectEqual(t, a.StopSignal, "SIGTERM") - ExpectEqual(t, b.StopSignal, "SIGTERM") + ExpectEqual(t, a.Container.StopSignal, "SIGTERM") + ExpectEqual(t, b.Container.StopSignal, "SIGTERM") } func TestApplyLabelWithAlias(t *testing.T) { @@ -186,16 +186,16 @@ func TestPublicIPLocalhost(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost) raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - ExpectEqual(t, raw.PublicIP, "127.0.0.1") - ExpectEqual(t, raw.Host, raw.PublicIP) + ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1") + ExpectEqual(t, raw.Host, raw.Container.PublicIP) } func TestPublicIPRemote(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - ExpectEqual(t, raw.PublicIP, "1.2.3.4") - ExpectEqual(t, raw.Host, raw.PublicIP) + ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") + ExpectEqual(t, raw.Host, raw.Container.PublicIP) } func TestPrivateIPLocalhost(t *testing.T) { @@ -211,8 +211,8 @@ func TestPrivateIPLocalhost(t *testing.T) { }, client.DefaultDockerHost) raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - ExpectEqual(t, raw.PrivateIP, "172.17.0.123") - ExpectEqual(t, raw.Host, raw.PrivateIP) + ExpectEqual(t, raw.Container.PrivateIP, "172.17.0.123") + ExpectEqual(t, raw.Host, raw.Container.PrivateIP) } func TestPrivateIPRemote(t *testing.T) { @@ -228,9 +228,9 @@ func TestPrivateIPRemote(t *testing.T) { }, "tcp://1.2.3.4:2375") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - ExpectEqual(t, raw.PrivateIP, "") - ExpectEqual(t, raw.PublicIP, "1.2.3.4") - ExpectEqual(t, raw.Host, raw.PublicIP) + ExpectEqual(t, raw.Container.PrivateIP, "") + ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") + ExpectEqual(t, raw.Host, raw.Container.PublicIP) } func TestStreamDefaultValues(t *testing.T) { diff --git a/internal/proxy/provider/provider.go b/internal/proxy/provider/provider.go index 62e2f76..2407efc 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/proxy/provider/provider.go @@ -5,6 +5,7 @@ import ( "path" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" W "github.com/yusing/go-proxy/internal/watcher" @@ -19,7 +20,7 @@ type ( routes R.Routes watcher W.Watcher - watcherCtx context.Context + watcherTask common.Task watcherCancel context.CancelFunc l *logrus.Entry @@ -38,9 +39,10 @@ type ( Type ProviderType `json:"type"` } EventResult struct { - nRemoved int - nAdded int - err E.NestedError + nAdded int + nRemoved int + nReloaded int + err E.NestedError } ) @@ -129,6 +131,7 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) { p.routes.RangeAllParallel(func(alias string, r *R.Route) { errors.Add(r.Stop().Subject(r)) }) + p.routes.Clear() return } @@ -175,27 +178,21 @@ func (p *Provider) Statistics() ProviderStats { } func (p *Provider) watchEvents() { - p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) - events, errs := p.watcher.Events(p.watcherCtx) + p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name) + defer p.watcherTask.Finished() + + events, errs := p.watcher.Events(p.watcherTask.Context()) l := p.l.WithField("module", "watcher") for { select { - case <-p.watcherCtx.Done(): + case <-p.watcherTask.Context().Done(): return case event := <-events: res := p.OnEvent(event, p.routes) - l.Infof("%s event %q", event.Type, event) - if res.nAdded > 0 || res.nRemoved > 0 { - n := res.nAdded - res.nRemoved - switch { - case n == 0: - l.Infof("%d route(s) reloaded", res.nAdded) - case n > 0: - l.Infof("%d route(s) added", n) - default: - l.Infof("%d route(s) removed", -n) - } + if res.nAdded+res.nRemoved+res.nReloaded > 0 { + l.Infof("%s event %q", event.Type, event) + l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded) } if res.err != nil { l.Error(res.err) diff --git a/internal/route/http.go b/internal/route/http.go index 904b8cc..79c6386 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -18,6 +18,7 @@ import ( url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/types" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -26,12 +27,12 @@ type ( HTTPRoute struct { *P.ReverseProxyEntry `json:"entry"` - LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"` - HealthMon health.HealthMonitor `json:"health"` + HealthMon health.HealthMonitor `json:"health,omitempty"` - server *loadbalancer.Server - handler http.Handler - rp *gphttp.ReverseProxy + loadBalancer *loadbalancer.LoadBalancer + server *loadbalancer.Server + handler http.Handler + rp *gphttp.ReverseProxy } SubdomainKey = PT.Alias @@ -102,10 +103,6 @@ func (r *HTTPRoute) URL() url.URL { } func (r *HTTPRoute) Start() E.NestedError { - if r.handler != nil { - return nil - } - if r.ShouldNotServe() { return nil } @@ -113,6 +110,10 @@ func (r *HTTPRoute) Start() E.NestedError { httpRoutesMu.Lock() defer httpRoutesMu.Unlock() + if r.handler != nil { + return nil + } + if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) { logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) r.HealthCheck.Disabled = true @@ -129,15 +130,23 @@ func (r *HTTPRoute) Start() E.NestedError { r.HealthMon = waker case !r.HealthCheck.Disabled: r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask("Reverse proxy "+r.String()), r.URL(), r.HealthCheck) - fallthrough - case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/": - r.handler = ReverseProxyHandler{r.rp} - default: - mux := http.NewServeMux() - for _, p := range r.PathPatterns { - mux.HandleFunc(string(p), r.rp.ServeHTTP) + } + + if r.handler == nil { + switch { + case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/": + r.handler = ReverseProxyHandler{r.rp} + default: + mux := http.NewServeMux() + for _, p := range r.PathPatterns { + mux.HandleFunc(string(p), r.rp.ServeHTTP) + } + r.handler = mux } - r.handler = mux + } + + if r.HealthMon != nil { + r.HealthMon.Start() } if r.UseLoadBalance() { @@ -146,9 +155,6 @@ func (r *HTTPRoute) Start() E.NestedError { httpRoutes.Store(string(r.Alias), r) } - if r.HealthMon != nil { - r.HealthMon.Start() - } return nil } @@ -160,7 +166,7 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { httpRoutesMu.Lock() defer httpRoutesMu.Unlock() - if r.LoadBalancer != nil { + if r.loadBalancer != nil { r.removeFromLoadBalancer() } else { httpRoutes.Delete(string(r.Alias)) @@ -184,29 +190,40 @@ func (r *HTTPRoute) addToLoadBalancer() { var lb *loadbalancer.LoadBalancer linked, ok := httpRoutes.Load(r.LoadBalance.Link) if ok { - lb = linked.LoadBalancer + lb = linked.loadBalancer + lb.UpdateConfigIfNeeded(r.LoadBalance) + if linked.Raw.Homepage == nil && r.Raw.Homepage != nil { + linked.Raw.Homepage = r.Raw.Homepage + } } else { lb = loadbalancer.New(r.LoadBalance) lb.Start() linked = &HTTPRoute{ - LoadBalancer: lb, + ReverseProxyEntry: &P.ReverseProxyEntry{ + Raw: &types.RawEntry{ + Homepage: r.Raw.Homepage, + }, + Alias: PT.Alias(lb.Link), + }, + HealthMon: lb, + loadBalancer: lb, handler: lb, } httpRoutes.Store(r.LoadBalance.Link, linked) } - r.LoadBalancer = lb + r.loadBalancer = lb r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) lb.AddServer(r.server) } func (r *HTTPRoute) removeFromLoadBalancer() { - r.LoadBalancer.RemoveServer(r.server) - if r.LoadBalancer.IsEmpty() { + r.loadBalancer.RemoveServer(r.server) + if r.loadBalancer.IsEmpty() { httpRoutes.Delete(r.LoadBalance.Link) logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link) } r.server = nil - r.LoadBalancer = nil + r.loadBalancer = nil } func ProxyHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/route/route.go b/internal/route/route.go index c718023..58d3782 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -1,6 +1,7 @@ package route import ( + "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" @@ -36,6 +37,13 @@ const ( // function alias. var NewRoutes = F.NewMap[Routes] +func (rt *Route) Container() *docker.Container { + if rt.Entry.Container == nil { + return docker.DummyContainer + } + return rt.Entry.Container +} + func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { entry, err := P.ValidateEntry(en) if err != nil { diff --git a/internal/route/stream.go b/internal/route/stream.go index aef8c0f..1d9c38b 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "sync" "github.com/sirupsen/logrus" @@ -24,12 +25,11 @@ type StreamRoute struct { url url.URL - wg sync.WaitGroup task common.Task cancel context.CancelFunc + done chan struct{} - connCh chan any - l logrus.FieldLogger + l logrus.FieldLogger mu sync.Mutex } @@ -61,7 +61,6 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { base := &StreamRoute{ StreamEntry: entry, url: url, - connCh: make(chan any, 100), } if entry.Scheme.ListeningScheme.IsTCP() { base.StreamImpl = NewTCPRoute(base) @@ -73,7 +72,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { } func (r *StreamRoute) String() string { - return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias) + return fmt.Sprintf("stream %s", r.Alias) } func (r *StreamRoute) URL() url.URL { @@ -88,14 +87,12 @@ func (r *StreamRoute) Start() E.NestedError { return nil } r.task, r.cancel = common.NewTaskWithCancel(r.String()) - r.wg.Wait() if err := r.Setup(); err != nil { return E.FailWith("setup", err) } + r.done = make(chan struct{}) r.l.Infof("listening on port %d", r.Port.ListeningPort) - r.wg.Add(2) go r.acceptConnections() - go r.handleConnections() if !r.Healthcheck.Disabled { r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck) r.HealthMon.Start() @@ -122,11 +119,7 @@ func (r *StreamRoute) Stop() E.NestedError { r.cancel() r.CloseListeners() - r.wg.Wait() - r.task.Finished() - - r.task, r.cancel = nil, nil - + <-r.done return nil } @@ -135,41 +128,45 @@ func (r *StreamRoute) Started() bool { } func (r *StreamRoute) acceptConnections() { - defer r.wg.Done() + var connWg sync.WaitGroup + + task := r.task.Subtask("%s accept connections", r.String()) + + defer func() { + connWg.Wait() + task.Finished() + r.task.Finished() + r.task, r.cancel = nil, nil + close(r.done) + r.done = nil + }() for { select { - case <-r.task.Context().Done(): + case <-task.Context().Done(): return default: conn, err := r.Accept() if err != nil { select { - case <-r.task.Context().Done(): + case <-task.Context().Done(): return default: - r.l.Error(err) + var nErr *net.OpError + ok := errors.As(err, &nErr) + if !(ok && nErr.Timeout()) { + r.l.Error(err) + } continue } } - r.connCh <- conn - } - } -} - -func (r *StreamRoute) handleConnections() { - defer r.wg.Done() - - for { - select { - case <-r.task.Context().Done(): - return - case conn := <-r.connCh: + connWg.Add(1) go func() { err := r.Handle(conn) if err != nil && !errors.Is(err, context.Canceled) { r.l.Error(err) } + connWg.Done() }() } } diff --git a/internal/route/tcp.go b/internal/route/tcp.go index b7367e1..e076a76 100755 --- a/internal/route/tcp.go +++ b/internal/route/tcp.go @@ -4,31 +4,25 @@ import ( "context" "fmt" "net" - "sync" "time" T "github.com/yusing/go-proxy/internal/proxy/fields" U "github.com/yusing/go-proxy/internal/utils" + F "github.com/yusing/go-proxy/internal/utils/functional" ) const tcpDialTimeout = 5 * time.Second type ( - Pipes []U.BidirectionalPipe - - TCPRoute struct { + TCPConnMap = F.Map[net.Conn, struct{}] + TCPRoute struct { *StreamRoute - listener net.Listener - pipe Pipes - mu sync.Mutex + listener *net.TCPListener } ) func NewTCPRoute(base *StreamRoute) StreamImpl { - return &TCPRoute{ - StreamRoute: base, - pipe: make(Pipes, 0), - } + return &TCPRoute{StreamRoute: base} } func (route *TCPRoute) Setup() error { @@ -38,11 +32,12 @@ func (route *TCPRoute) Setup() error { } //! this read the allocated port from original ':0' route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) - route.listener = in + route.listener = in.(*net.TCPListener) return nil } func (route *TCPRoute) Accept() (any, error) { + route.listener.SetDeadline(time.Now().Add(time.Second)) return route.listener.Accept() } @@ -50,24 +45,23 @@ func (route *TCPRoute) Handle(c any) error { clientConn := c.(net.Conn) defer clientConn.Close() + go func() { + <-route.task.Context().Done() + clientConn.Close() + }() ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) - defer cancel() serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) dialer := &net.Dialer{} serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) + cancel() if err != nil { return err } - route.mu.Lock() - pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) - route.pipe = append(route.pipe, pipe) - - route.mu.Unlock() return pipe.Start() } diff --git a/internal/route/udp.go b/internal/route/udp.go index c15630c..faadff6 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "time" T "github.com/yusing/go-proxy/internal/proxy/fields" U "github.com/yusing/go-proxy/internal/utils" @@ -67,6 +68,7 @@ func (route *UDPRoute) Accept() (any, error) { in := route.listeningConn buffer := make([]byte, udpBufferSize) + route.listeningConn.SetReadDeadline(time.Now().Add(time.Second)) nRead, srcAddr, err := in.ReadFromUDP(buffer) if err != nil { return nil, err diff --git a/internal/server/server.go b/internal/server/server.go index 44e349b..7bf2120 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -84,7 +84,7 @@ func NewServer(opt Options) (s *Server) { CertProvider: opt.CertProvider, http: httpSer, https: httpsSer, - task: common.GlobalTask("Server " + opt.Name), + task: common.GlobalTask(opt.Name + " server"), } } @@ -133,13 +133,11 @@ func (s *Server) stop() { if s.http != nil && s.httpStarted { s.handleErr("http", s.http.Shutdown(ctx)) s.httpStarted = false - logger.Debugf("HTTP server %q stopped", s.Name) } if s.https != nil && s.httpsStarted { s.handleErr("https", s.https.Shutdown(ctx)) s.httpsStarted = false - logger.Debugf("HTTPS server %q stopped", s.Name) } } @@ -152,7 +150,7 @@ func (s *Server) handleErr(scheme string, err error) { case err == nil, errors.Is(err, http.ErrServerClosed): return default: - logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err) + logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err) } } diff --git a/internal/types/raw_entry.go b/internal/types/raw_entry.go index 029c944..9d8a56d 100644 --- a/internal/types/raw_entry.go +++ b/internal/types/raw_entry.go @@ -22,9 +22,9 @@ type ( // raw entry object before validation // loaded from docker labels or yaml file Alias string `json:"-" yaml:"-"` - Scheme string `json:"scheme" yaml:"scheme"` - Host string `json:"host" yaml:"host"` - Port string `json:"port" yaml:"port"` + Scheme string `json:"scheme,omitempty" yaml:"scheme"` + Host string `json:"host,omitempty" yaml:"host"` + Port string `json:"port,omitempty" yaml:"port"` NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` @@ -33,7 +33,7 @@ type ( Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` /* Docker only */ - *docker.Container `json:"container" yaml:"-"` + Container *docker.Container `json:"container,omitempty" yaml:"-"` } RawEntries = F.Map[string, *RawEntry] @@ -43,16 +43,17 @@ var NewProxyEntries = F.NewMapOf[string, *RawEntry] func (e *RawEntry) FillMissingFields() { isDocker := e.Container != nil + cont := e.Container if !isDocker { - e.Container = &docker.Container{} + cont = docker.DummyContainer } if e.Host == "" { switch { - case e.PrivateIP != "": - e.Host = e.PrivateIP - case e.PublicIP != "": - e.Host = e.PublicIP + case cont.PrivateIP != "": + e.Host = cont.PrivateIP + case cont.PublicIP != "": + e.Host = cont.PublicIP case !isDocker: e.Host = "localhost" } @@ -60,14 +61,14 @@ func (e *RawEntry) FillMissingFields() { lp, pp, extra := e.splitPorts() - if port, ok := common.ServiceNamePortMapTCP[e.ImageName]; ok { + if port, ok := common.ServiceNamePortMapTCP[cont.ImageName]; ok { if pp == "" { pp = strconv.Itoa(port) } if e.Scheme == "" { e.Scheme = "tcp" } - } else if port, ok := common.ImageNamePortMap[e.ImageName]; ok { + } else if port, ok := common.ImageNamePortMap[cont.ImageName]; ok { if pp == "" { pp = strconv.Itoa(port) } @@ -77,9 +78,9 @@ func (e *RawEntry) FillMissingFields() { } else if pp == "" && e.Scheme == "https" { pp = "443" } else if pp == "" { - if p := lowestPort(e.PrivatePortMapping); p != "" { + if p := lowestPort(cont.PrivatePortMapping); p != "" { pp = p - } else if p := lowestPort(e.PublicPortMapping); p != "" { + } else if p := lowestPort(cont.PublicPortMapping); p != "" { pp = p } else if !isDocker { pp = "80" @@ -89,23 +90,23 @@ func (e *RawEntry) FillMissingFields() { } // replace private port with public port if using public IP. - if e.Host == e.PublicIP { - if p, ok := e.PrivatePortMapping[pp]; ok { + if e.Host == cont.PublicIP { + if p, ok := cont.PrivatePortMapping[pp]; ok { pp = U.PortString(p.PublicPort) } } // replace public port with private port if using private IP. - if e.Host == e.PrivateIP { - if p, ok := e.PublicPortMapping[pp]; ok { + if e.Host == cont.PrivateIP { + if p, ok := cont.PublicPortMapping[pp]; ok { pp = U.PortString(p.PrivatePort) } } if e.Scheme == "" && isDocker { switch { - case e.Host == e.PublicIP && e.PublicPortMapping[pp].Type == "udp": + case e.Host == cont.PublicIP && cont.PublicPortMapping[pp].Type == "udp": e.Scheme = "udp" - case e.Host == e.PrivateIP && e.PrivatePortMapping[pp].Type == "udp": + case e.Host == cont.PrivateIP && cont.PrivatePortMapping[pp].Type == "udp": e.Scheme = "udp" } } @@ -127,17 +128,17 @@ func (e *RawEntry) FillMissingFields() { if e.HealthCheck.Timeout == 0 { e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault } - if e.IdleTimeout == "" { - e.IdleTimeout = common.IdleTimeoutDefault + if cont.IdleTimeout == "" { + cont.IdleTimeout = common.IdleTimeoutDefault } - if e.WakeTimeout == "" { - e.WakeTimeout = common.WakeTimeoutDefault + if cont.WakeTimeout == "" { + cont.WakeTimeout = common.WakeTimeoutDefault } - if e.StopTimeout == "" { - e.StopTimeout = common.StopTimeoutDefault + if cont.StopTimeout == "" { + cont.StopTimeout = common.StopTimeoutDefault } - if e.StopMethod == "" { - e.StopMethod = common.StopMethodDefault + if cont.StopMethod == "" { + cont.StopMethod = common.StopMethodDefault } e.Port = joinPorts(lp, pp, extra) diff --git a/internal/utils/io.go b/internal/utils/io.go index 406197b..18a8773 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -99,9 +99,53 @@ func (p BidirectionalPipe) Start() error { return b.Build().Error() } -func Copy(dst *ContextWriter, src *ContextReader) error { - _, err := io.Copy(dst, src) - return err +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// This is a copy of io.Copy with context handling +// Author: yusing +func Copy(dst *ContextWriter, src *ContextReader) (err error) { + size := 32 * 1024 + if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + buf := make([]byte, size) + for { + select { + case <-src.ctx.Done(): + return src.ctx.Err() + case <-dst.ctx.Done(): + return dst.ctx.Err() + default: + nr, er := src.Reader.Read(buf) + if nr > 0 { + nw, ew := dst.Writer.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write result") + } + } + if ew != nil { + err = ew + return + } + if nr != nw { + err = io.ErrShortWrite + return + } + } + if er != nil { + if er != io.EOF { + err = er + } + return + } + } + } } func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index f35823a..9d2e8d6 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -99,12 +99,10 @@ func Serialize(data any) (SerializedObject, E.NestedError) { continue // Ignore this field if the tag is "-" } if strings.Contains(jsonTag, ",omitempty") { - if field.Type.Kind() == reflect.Ptr && value.Field(i).IsNil() { - continue - } if value.Field(i).IsZero() { continue } + jsonTag = strings.Replace(jsonTag, ",omitempty", "", 1) } // If the json tag is not empty, use it as the key diff --git a/internal/watcher/health/healthcheck_config.go b/internal/watcher/health/healthcheck_config.go index 31e0043..293caa9 100644 --- a/internal/watcher/health/healthcheck_config.go +++ b/internal/watcher/health/healthcheck_config.go @@ -7,7 +7,7 @@ import ( ) type HealthCheckConfig struct { - Disabled bool `json:"disabled" yaml:"disabled"` + Disabled bool `json:"disabled,omitempty" yaml:"disabled"` Path string `json:"path,omitempty" yaml:"path"` UseGet bool `json:"use_get,omitempty" yaml:"use_get"` Interval time.Duration `json:"interval" yaml:"interval"` diff --git a/internal/watcher/health/json.go b/internal/watcher/health/json.go new file mode 100644 index 0000000..cefe4c9 --- /dev/null +++ b/internal/watcher/health/json.go @@ -0,0 +1,30 @@ +package health + +import ( + "encoding/json" + "time" + + "github.com/yusing/go-proxy/internal/net/types" +) + +type JSONRepresentation struct { + Name string + Config *HealthCheckConfig + Status Status + Started time.Time + Uptime time.Duration + URL types.URL + Extra map[string]any +} + +func (jsonRepr *JSONRepresentation) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]any{ + "name": jsonRepr.Name, + "config": jsonRepr.Config, + "started": jsonRepr.Started.Unix(), + "status": jsonRepr.Status.String(), + "uptime": jsonRepr.Uptime.Seconds(), + "url": jsonRepr.URL.String(), + "extra": jsonRepr.Extra, + }) +} diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index ab62dde..37f3e39 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -2,7 +2,6 @@ package health import ( "context" - "encoding/json" "errors" "sync" "time" @@ -25,8 +24,9 @@ type ( } HealthCheckFunc func() (healthy bool, detail string, err error) monitor struct { - config *HealthCheckConfig - url types.URL + service string + config *HealthCheckConfig + url types.URL status U.AtomicValue[Status] checkHealth HealthCheckFunc @@ -43,8 +43,10 @@ type ( var monMap = F.NewMapOf[string, HealthMonitor]() func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { - task, cancel := task.SubtaskWithCancel("Health monitor for %s", task.Name()) + service := task.Name() + task, cancel := task.SubtaskWithCancel("Health monitor for %s", service) mon := &monitor{ + service: service, config: config, url: url, checkHealth: healthCheckFunc, @@ -57,17 +59,12 @@ func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, heal return mon } -func Inspect(name string) (status Status, ok bool) { - mon, ok := monMap.Load(name) - if !ok { - return - } - return mon.Status(), true +func Inspect(name string) (HealthMonitor, bool) { + return monMap.Load(name) } func (mon *monitor) Start() { defer monMap.Store(mon.task.Name(), mon) - defer logger.Debugf("%s health monitor started", mon.String()) go func() { defer close(mon.done) @@ -93,12 +90,9 @@ func (mon *monitor) Start() { } } }() - logger.Debugf("health monitor %q started", mon.String()) } func (mon *monitor) Stop() { - defer logger.Debugf("%s health monitor stopped", mon.String()) - monMap.Delete(mon.task.Name()) mon.mu.Lock() @@ -132,14 +126,14 @@ func (mon *monitor) String() string { } func (mon *monitor) MarshalJSON() ([]byte, error) { - return json.Marshal(map[string]any{ - "name": mon.Name(), - "url": mon.url, - "status": mon.status.Load(), - "uptime": mon.Uptime().String(), - "started": mon.startTime.Unix(), - "config": mon.config, - }) + return (&JSONRepresentation{ + Name: mon.Name(), + Config: mon.config, + Status: mon.status.Load(), + Started: mon.startTime, + Uptime: mon.Uptime(), + URL: mon.url, + }).MarshalJSON() } func (mon *monitor) checkUpdateHealth() (hasError bool) { @@ -147,7 +141,7 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) { if err != nil { mon.status.Store(StatusError) if !errors.Is(err, context.Canceled) { - logger.Errorf("%s failed to check health: %s", mon.String(), err) + logger.Errorf("%s failed to check health: %s", mon.service, err) } mon.Stop() return false @@ -160,9 +154,9 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) { } if healthy != (mon.status.Swap(status) == StatusHealthy) { if healthy { - logger.Infof("%s is up", mon.String()) + logger.Infof("%s is up", mon.service) } else { - logger.Warnf("%s is down: %s", mon.String(), detail) + logger.Warnf("%s is down: %s", mon.service, detail) } } diff --git a/internal/watcher/health/status.go b/internal/watcher/health/status.go index 3c30c9d..9d4fd29 100644 --- a/internal/watcher/health/status.go +++ b/internal/watcher/health/status.go @@ -1,7 +1,5 @@ package health -import "encoding/json" - type Status int const ( @@ -36,7 +34,7 @@ func (s Status) String() string { } func (s Status) MarshalJSON() ([]byte, error) { - return json.Marshal(s.String()) + return []byte(`"` + s.String() + `"`), nil } func (s Status) Good() bool {