diff --git a/cmd/main.go b/cmd/main.go index 66927ec..413ca53 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,7 +11,6 @@ import ( "reflect" "runtime" "strings" - "sync" "syscall" "time" @@ -137,13 +136,11 @@ func main() { signal.Notify(sig, syscall.SIGHUP) autocert := cfg.GetAutoCertProvider() - if autocert != nil { ctx, cancel := context.WithCancel(context.Background()) + onShutdown.Add(cancel) if err := autocert.Setup(ctx); err != nil { l.Fatal(err) - } else { - onShutdown.Add(cancel) } } else { l.Info("autocert not configured") @@ -179,19 +176,15 @@ func main() { // grafully shutdown logrus.Info("shutting down") done := make(chan struct{}, 1) + currentIdx := 0 - var wg sync.WaitGroup - wg.Add(onShutdown.Size()) - onShutdown.ForEach(func(f func()) { - go func() { + go func() { + onShutdown.ForEach(func(f func()) { l.Debugf("waiting for %s to complete...", funcName(f)) f() + currentIdx++ l.Debugf("%s done", funcName(f)) - wg.Done() - }() - }) - go func() { - wg.Wait() + }) close(done) }() @@ -201,9 +194,9 @@ func main() { logrus.Info("shutdown complete") case <-timeout: logrus.Info("timeout waiting for shutdown") - onShutdown.ForEach(func(f func()) { - l.Warnf("%s() is still running", funcName(f)) - }) + for i := currentIdx; i < onShutdown.Size(); i++ { + l.Warnf("%s() is still running", funcName(onShutdown.Get(i))) + } } } diff --git a/internal/config/config.go b/internal/config/config.go index 370ad25..071a524 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -220,7 +220,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { errors := E.NewBuilder("errors in %s these providers", action) - cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { + cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) { if err := do(p); err.HasError() { errors.Add(err.Subject(p)) } diff --git a/internal/docker/client.go b/internal/docker/client.go index 4132857..8d2845f 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -133,7 +133,7 @@ func ConnectClient(host string) (Client, E.NestedError) { } func CloseAllClients() { - clientMap.RangeAll(func(_ string, c Client) { + clientMap.RangeAllParallel(func(_ string, c Client) { c.Client.Close() }) clientMap.Clear() diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 4de47db..466083f 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -46,7 +46,7 @@ func New(cfg Config) *LoadBalancer { lb := &LoadBalancer{Config: cfg, pool: servers{}} mode := cfg.Mode if !cfg.Mode.ValidateUpdate() { - logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode) + logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode) } switch mode { case RoundRobin: @@ -69,6 +69,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.sumWeight += srv.Weight lb.impl.OnAddServer(srv) + logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool)) } func (lb *LoadBalancer) RemoveServer(srv *Server) { @@ -85,7 +86,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { } if lb.IsEmpty() { lb.Stop() + return } + + lb.Rebalance() + logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool)) } func (lb *LoadBalancer) IsEmpty() bool { @@ -98,14 +103,14 @@ func (lb *LoadBalancer) Rebalance() { } if lb.sumWeight == 0 { // distribute evenly weightEach := maxWeight / weightType(len(lb.pool)) - remainer := maxWeight % weightType(len(lb.pool)) + remainder := maxWeight % weightType(len(lb.pool)) for _, s := range lb.pool { s.Weight = weightEach lb.sumWeight += weightEach - if remainer > 0 { + if remainder > 0 { s.Weight++ + remainder-- } - remainer-- } return } @@ -149,17 +154,13 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } func (lb *LoadBalancer) Start() { - if lb.IsEmpty() { - return - } - if lb.sumWeight != 0 && lb.sumWeight != maxWeight { msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight) for _, s := range lb.pool { msg.Addf("%s: %d", s.Name, s.Weight) } lb.Rebalance() - inner := E.NewBuilder("After rebalancing") + inner := E.NewBuilder("after rebalancing") for _, s := range lb.pool { inner.Addf("%s: %d", s.Name, s.Weight) } @@ -168,22 +169,16 @@ func (lb *LoadBalancer) Start() { } if lb.sumWeight != 0 { - log.Warnf("Weighted mode not supported yet") - } - - switch lb.Mode { - case RoundRobin: - lb.impl = lb.newRoundRobin() - case LeastConn: - lb.impl = lb.newLeastConn() - case IPHash: - lb.impl = lb.newIPHash() + log.Warnf("weighted mode not supported yet") } lb.done = make(chan struct{}, 1) lb.ctx, lb.cancel = context.WithCancel(context.Background()) updateAll := func() { + lb.poolMu.Lock() + defer lb.poolMu.Unlock() + var wg sync.WaitGroup wg.Add(len(lb.pool)) for _, s := range lb.pool { @@ -195,6 +190,8 @@ func (lb *LoadBalancer) Start() { wg.Wait() } + logger.Debugf("loadbalancer %s started", lb.Link) + go func() { defer lb.cancel() defer close(lb.done) @@ -208,16 +205,14 @@ func (lb *LoadBalancer) Start() { case <-lb.ctx.Done(): return case <-ticker.C: - lb.poolMu.RLock() updateAll() - lb.poolMu.RUnlock() } } }() } func (lb *LoadBalancer) Stop() { - if lb.impl == nil { + if lb.cancel == nil { return } @@ -225,6 +220,8 @@ func (lb *LoadBalancer) Stop() { <-lb.done lb.pool = nil + + logger.Debugf("loadbalancer %s stopped", lb.Link) } func (lb *LoadBalancer) availServers() servers { diff --git a/internal/net/http/loadbalancer/server.go b/internal/net/http/loadbalancer/server.go index 798693b..52fff16 100644 --- a/internal/net/http/loadbalancer/server.go +++ b/internal/net/http/loadbalancer/server.go @@ -26,7 +26,7 @@ func NewServer(name string, url types.URL, weight weightType, handler http.Handl srv := &Server{ Name: name, URL: url, - Weight: weightType(weight), + Weight: weight, handler: handler, pinger: &http.Client{Timeout: 3 * time.Second}, } diff --git a/internal/proxy/provider/file.go b/internal/proxy/provider/file.go index 289641c..e3df8bd 100644 --- a/internal/proxy/provider/file.go +++ b/internal/proxy/provider/file.go @@ -52,12 +52,12 @@ func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) return } - routes.RangeAll(func(_ string, v R.Route) { + routes.RangeAllParallel(func(_ string, v R.Route) { b.Add(v.Stop()) }) routes.Clear() - newRoutes.RangeAll(func(_ string, v R.Route) { + newRoutes.RangeAllParallel(func(_ string, v R.Route) { b.Add(v.Start()) }) diff --git a/internal/proxy/provider/provider.go b/internal/proxy/provider/provider.go index b17044c..322789b 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/proxy/provider/provider.go @@ -114,7 +114,7 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) { nStarted := 0 nFailed := 0 - p.routes.RangeAll(func(alias string, r R.Route) { + p.routes.RangeAllParallel(func(alias string, r R.Route) { if err := r.Start(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ @@ -138,7 +138,7 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) { nStopped := 0 nFailed := 0 - p.routes.RangeAll(func(alias string, r R.Route) { + p.routes.RangeAllParallel(func(alias string, r R.Route) { if err := r.Stop(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ diff --git a/internal/route/http.go b/internal/route/http.go index 61d0378..fff7dd3 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -22,7 +22,7 @@ import ( type ( HTTPRoute struct { *P.ReverseProxyEntry - LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"` + LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer"` server *loadbalancer.Server handler http.Handler @@ -134,7 +134,8 @@ func (r *HTTPRoute) Start() E.NestedError { } httpRoutes.Store(string(r.LoadBalance.Link), linked) } - lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)) + r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler) + lb.AddServer(r.server) return nil } @@ -150,8 +151,6 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { waker.Unregister() } - r.handler = nil - if r.server != nil { linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) if ok { @@ -160,9 +159,13 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { if linked.LoadBalancer.IsEmpty() { httpRoutes.Delete(string(r.LoadBalance.Link)) } + r.server = nil } else { httpRoutes.Delete(string(r.Alias)) } + + r.handler = nil + return } diff --git a/internal/route/udp.go b/internal/route/udp.go index 83227b4..cecc16d 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -108,7 +108,7 @@ func (route *UDPRoute) CloseListeners() { route.listeningConn.Close() route.listeningConn = nil } - route.connMap.RangeAll(func(_ string, conn *UDPConn) { + route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) { if err := conn.src.Close(); err != nil { route.l.Errorf("error closing src conn: %s", err) } diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index e0d3284..b578e5b 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -1,6 +1,8 @@ package functional import ( + "sync" + "github.com/puzpuzpuz/xsync/v3" "gopkg.in/yaml.v3" @@ -75,6 +77,20 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { }) } +func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { + var wg sync.WaitGroup + wg.Add(m.Size()) + + m.Range(func(k KT, v VT) bool { + go func() { + do(k, v) + wg.Done() + }() + return true + }) + wg.Wait() +} + func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) { m.Range(func(k KT, v VT) bool { if criteria(v) { diff --git a/internal/utils/functional/slice.go b/internal/utils/functional/slice.go index d992aae..035e10e 100644 --- a/internal/utils/functional/slice.go +++ b/internal/utils/functional/slice.go @@ -38,6 +38,10 @@ func (s *Slice[T]) Iterator() []T { return s.s } +func (s *Slice[T]) Get(i int) T { + return s.s[i] +} + func (s *Slice[T]) Set(i int, v T) { s.s[i] = v } @@ -76,6 +80,20 @@ func (s *Slice[T]) SafePop() T { return s.Pop() } +func (s *Slice[T]) Remove(criteria func(T) bool) { + for i, v2 := range s.s { + if criteria(v2) { + s.s = append(s.s[:i], s.s[i+1:]...) + } + } +} + +func (s *Slice[T]) SafeRemove(criteria func(T) bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.Remove(criteria) +} + func (s *Slice[T]) ForEach(do func(T)) { for _, v := range s.s { do(v)