From b253dce7e19a74623ef3b3e1d40e9ea4f591df2d Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 19 Jan 2025 04:32:50 +0800 Subject: [PATCH] cleanup some loadbalancer code --- internal/net/http/loadbalancer/ip_hash.go | 4 +- internal/net/http/loadbalancer/least_conn.go | 12 +-- .../net/http/loadbalancer/loadbalancer.go | 94 ++++++++++++------- .../http/loadbalancer/loadbalancer_test.go | 27 +++--- internal/net/http/loadbalancer/round_robin.go | 6 +- internal/net/http/loadbalancer/types.go | 2 +- .../net/http/loadbalancer/types/server.go | 76 +++++++++------ internal/watcher/health/health_checker.go | 8 +- 8 files changed, 136 insertions(+), 93 deletions(-) diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index cbb6ab0..384f7cf 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -31,7 +31,7 @@ func (lb *LoadBalancer) newIPHash() impl { return impl } -func (impl *ipHash) OnAddServer(srv *Server) { +func (impl *ipHash) OnAddServer(srv Server) { impl.mu.Lock() defer impl.mu.Unlock() @@ -48,7 +48,7 @@ func (impl *ipHash) OnAddServer(srv *Server) { impl.pool = append(impl.pool, srv) } -func (impl *ipHash) OnRemoveServer(srv *Server) { +func (impl *ipHash) OnRemoveServer(srv Server) { impl.mu.Lock() defer impl.mu.Unlock() diff --git a/internal/net/http/loadbalancer/least_conn.go b/internal/net/http/loadbalancer/least_conn.go index 3363915..7130e42 100644 --- a/internal/net/http/loadbalancer/least_conn.go +++ b/internal/net/http/loadbalancer/least_conn.go @@ -9,21 +9,21 @@ import ( type leastConn struct { *LoadBalancer - nConn F.Map[*Server, *atomic.Int64] + nConn F.Map[Server, *atomic.Int64] } func (lb *LoadBalancer) newLeastConn() impl { return &leastConn{ LoadBalancer: lb, - nConn: F.NewMapOf[*Server, *atomic.Int64](), + nConn: F.NewMapOf[Server, *atomic.Int64](), } } -func (impl *leastConn) OnAddServer(srv *Server) { +func (impl *leastConn) OnAddServer(srv Server) { impl.nConn.Store(srv, new(atomic.Int64)) } -func (impl *leastConn) OnRemoveServer(srv *Server) { +func (impl *leastConn) OnRemoveServer(srv Server) { impl.nConn.Delete(srv) } @@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.R srv := srvs[0] minConn, ok := impl.nConn.Load(srv) if !ok { - impl.l.Error().Msgf("[BUG] server %s not found", srv.Name) + impl.l.Error().Msgf("[BUG] server %s not found", srv.Name()) http.Error(rw, "Internal error", http.StatusInternalServerError) } for i := 1; i < len(srvs); i++ { nConn, ok := impl.nConn.Load(srvs[i]) if !ok { - impl.l.Error().Msgf("[BUG] server %s not found", srv.Name) + impl.l.Error().Msgf("[BUG] server %s not found", srv.Name()) http.Error(rw, "Internal error", http.StatusInternalServerError) } if nConn.Load() < minConn.Load() { diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index bea55b6..c4b8d71 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -20,8 +20,8 @@ import ( type ( impl interface { ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) - OnAddServer(srv *Server) - OnRemoveServer(srv *Server) + OnAddServer(srv Server) + OnRemoveServer(srv Server) } LoadBalancer struct { @@ -61,7 +61,7 @@ func (lb *LoadBalancer) Start(parent task.Parent) E.Error { }) lb.task.OnFinished("cleanup", func() { if lb.impl != nil { - lb.pool.RangeAll(func(k string, v *Server) { + lb.pool.RangeAll(func(k string, v Server) { lb.impl.OnRemoveServer(v) }) } @@ -90,7 +90,7 @@ func (lb *LoadBalancer) updateImpl() { default: // should happen in test only lb.impl = lb.newRoundRobin() } - lb.pool.RangeAll(func(_ string, srv *Server) { + lb.pool.RangeAll(func(_ string, srv Server) { lb.impl.OnAddServer(srv) }) } @@ -120,44 +120,44 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { } } -func (lb *LoadBalancer) AddServer(srv *Server) { +func (lb *LoadBalancer) AddServer(srv Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - if lb.pool.Has(srv.Name) { - old, _ := lb.pool.Load(srv.Name) - lb.sumWeight -= old.Weight + if lb.pool.Has(srv.Name()) { + old, _ := lb.pool.Load(srv.Name()) + lb.sumWeight -= old.Weight() lb.impl.OnRemoveServer(old) } - lb.pool.Store(srv.Name, srv) - lb.sumWeight += srv.Weight + lb.pool.Store(srv.Name(), srv) + lb.sumWeight += srv.Weight() lb.rebalance() lb.impl.OnAddServer(srv) lb.l.Debug(). Str("action", "add"). - Str("server", srv.Name). + Str("server", srv.Name()). Msgf("%d servers available", lb.pool.Size()) } -func (lb *LoadBalancer) RemoveServer(srv *Server) { +func (lb *LoadBalancer) RemoveServer(srv Server) { lb.poolMu.Lock() defer lb.poolMu.Unlock() - if !lb.pool.Has(srv.Name) { + if !lb.pool.Has(srv.Name()) { return } - lb.pool.Delete(srv.Name) + lb.pool.Delete(srv.Name()) - lb.sumWeight -= srv.Weight + lb.sumWeight -= srv.Weight() lb.rebalance() lb.impl.OnRemoveServer(srv) lb.l.Debug(). Str("action", "remove"). - Str("server", srv.Name). + Str("server", srv.Name()). Msgf("%d servers left", lb.pool.Size()) if lb.pool.Size() == 0 { @@ -178,13 +178,14 @@ func (lb *LoadBalancer) rebalance() { if lb.sumWeight == 0 { // distribute evenly weightEach := maxWeight / Weight(poolSize) remainder := maxWeight % Weight(poolSize) - lb.pool.RangeAll(func(_ string, s *Server) { - s.Weight = weightEach + lb.pool.RangeAll(func(_ string, s Server) { + w := weightEach lb.sumWeight += weightEach if remainder > 0 { - s.Weight++ + w++ remainder-- } + s.SetWeight(w) }) return } @@ -193,25 +194,25 @@ func (lb *LoadBalancer) rebalance() { scaleFactor := float64(maxWeight) / float64(lb.sumWeight) lb.sumWeight = 0 - lb.pool.RangeAll(func(_ string, s *Server) { - s.Weight = Weight(float64(s.Weight) * scaleFactor) - lb.sumWeight += s.Weight + lb.pool.RangeAll(func(_ string, s Server) { + s.SetWeight(Weight(float64(s.Weight()) * scaleFactor)) + lb.sumWeight += s.Weight() }) delta := maxWeight - lb.sumWeight if delta == 0 { return } - lb.pool.Range(func(_ string, s *Server) bool { + lb.pool.Range(func(_ string, s Server) bool { if delta == 0 { return false } if delta > 0 { - s.Weight++ + s.SetWeight(s.Weight() + 1) lb.sumWeight++ delta-- } else { - s.Weight-- + s.SetWeight(s.Weight() - 1) lb.sumWeight-- delta++ } @@ -229,22 +230,20 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // wake all servers for _, srv := range srvs { if err := srv.TryWake(); err != nil { - lb.l.Warn().Err(err).Str("server", srv.Name).Msg("failed to wake server") + lb.l.Warn().Err(err). + Str("server", srv.Name()). + Msg("failed to wake server") } } } lb.impl.ServeHTTP(srvs, rw, r) } -func (lb *LoadBalancer) Uptime() time.Duration { - return time.Since(lb.startTime) -} - // MarshalJSON implements health.HealthMonitor. func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { extra := make(map[string]any) - lb.pool.RangeAll(func(k string, v *Server) { - extra[v.Name] = v.HealthMonitor() + lb.pool.RangeAll(func(k string, v Server) { + extra[v.Name()] = v }) return (&monitor.JSONRepresentation{ @@ -269,20 +268,43 @@ func (lb *LoadBalancer) Status() health.Status { if lb.pool.Size() == 0 { return health.StatusUnknown } - if len(lb.availServers()) == 0 { + + isHealthy := true + lb.pool.Range(func(_ string, srv Server) bool { + if srv.Status().Bad() { + isHealthy = false + return false + } + return true + }) + if !isHealthy { return health.StatusUnhealthy } return health.StatusHealthy } +// Uptime implements health.HealthMonitor. +func (lb *LoadBalancer) Uptime() time.Duration { + return time.Since(lb.startTime) +} + +// Latency implements health.HealthMonitor. +func (lb *LoadBalancer) Latency() time.Duration { + var sum time.Duration + lb.pool.RangeAll(func(_ string, srv Server) { + sum += srv.Latency() + }) + return sum +} + // String implements health.HealthMonitor. func (lb *LoadBalancer) String() string { return lb.Name() } -func (lb *LoadBalancer) availServers() []*Server { - avail := make([]*Server, 0, lb.pool.Size()) - lb.pool.RangeAll(func(_ string, srv *Server) { +func (lb *LoadBalancer) availServers() []Server { + avail := make([]Server, 0, lb.pool.Size()) + lb.pool.RangeAll(func(_ string, srv Server) { if srv.Status().Good() { avail = append(avail, srv) } diff --git a/internal/net/http/loadbalancer/loadbalancer_test.go b/internal/net/http/loadbalancer/loadbalancer_test.go index 234130c..349565e 100644 --- a/internal/net/http/loadbalancer/loadbalancer_test.go +++ b/internal/net/http/loadbalancer/loadbalancer_test.go @@ -3,6 +3,7 @@ package loadbalancer import ( "testing" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -12,31 +13,31 @@ func TestRebalance(t *testing.T) { t.Run("zero", func(t *testing.T) { lb := New(new(loadbalance.Config)) for range 10 { - lb.AddServer(&Server{}) + lb.AddServer(types.TestNewServer(0)) } lb.rebalance() ExpectEqual(t, lb.sumWeight, maxWeight) }) t.Run("less", func(t *testing.T) { lb := New(new(loadbalance.Config)) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .2)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .3)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .2)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) ExpectEqual(t, lb.sumWeight, maxWeight) }) t.Run("more", func(t *testing.T) { lb := New(new(loadbalance.Config)) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .4)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) - lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .2)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .3)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .4)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .3)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .2)) + lb.AddServer(types.TestNewServer(float64(maxWeight) * .1)) lb.rebalance() // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) ExpectEqual(t, lb.sumWeight, maxWeight) diff --git a/internal/net/http/loadbalancer/round_robin.go b/internal/net/http/loadbalancer/round_robin.go index 494c21e..09d6770 100644 --- a/internal/net/http/loadbalancer/round_robin.go +++ b/internal/net/http/loadbalancer/round_robin.go @@ -9,9 +9,9 @@ type roundRobin struct { index atomic.Uint32 } -func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} } -func (lb *roundRobin) OnAddServer(srv *Server) {} -func (lb *roundRobin) OnRemoveServer(srv *Server) {} +func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} } +func (lb *roundRobin) OnAddServer(srv Server) {} +func (lb *roundRobin) OnRemoveServer(srv Server) {} func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) { index := lb.index.Add(1) % uint32(len(srvs)) diff --git a/internal/net/http/loadbalancer/types.go b/internal/net/http/loadbalancer/types.go index aa60369..36b45ad 100644 --- a/internal/net/http/loadbalancer/types.go +++ b/internal/net/http/loadbalancer/types.go @@ -6,7 +6,7 @@ import ( type ( Server = types.Server - Servers = types.Servers + Servers = []types.Server Pool = types.Pool Weight = types.Weight Config = types.Config diff --git a/internal/net/http/loadbalancer/types/server.go b/internal/net/http/loadbalancer/types/server.go index c6b9537..b3df394 100644 --- a/internal/net/http/loadbalancer/types/server.go +++ b/internal/net/http/loadbalancer/types/server.go @@ -2,7 +2,6 @@ package types import ( "net/http" - "time" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" "github.com/yusing/go-proxy/internal/net/types" @@ -12,51 +11,72 @@ import ( ) type ( - Server struct { + server struct { _ U.NoCopy - Name string - URL types.URL - Weight Weight + name string + url types.URL + weight Weight - handler http.Handler - healthMon health.HealthMonitor + http.Handler `json:"-"` + health.HealthMonitor } - Servers = []*Server - Pool = F.Map[string, *Server] + + Server interface { + http.Handler + health.HealthMonitor + Name() string + URL() types.URL + Weight() Weight + SetWeight(Weight) + TryWake() error + } + + Pool = F.Map[string, Server] ) var NewServerPool = F.NewMap[Pool] -func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) *Server { - srv := &Server{ - Name: name, - URL: url, - Weight: weight, - handler: handler, - healthMon: healthMon, +func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server { + srv := &server{ + name: name, + url: url, + weight: weight, + Handler: handler, + HealthMonitor: healthMon, } return srv } -func (srv *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - srv.handler.ServeHTTP(rw, r) +func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server { + srv := &server{ + weight: Weight(weight), + } + return srv } -func (srv *Server) String() string { - return srv.Name +func (srv *server) Name() string { + return srv.name } -func (srv *Server) Status() health.Status { - return srv.healthMon.Status() +func (srv *server) URL() types.URL { + return srv.url } -func (srv *Server) Uptime() time.Duration { - return srv.healthMon.Uptime() +func (srv *server) Weight() Weight { + return srv.weight } -func (srv *Server) TryWake() error { - waker, ok := srv.handler.(idlewatcher.Waker) +func (srv *server) SetWeight(weight Weight) { + srv.weight = weight +} + +func (srv *server) String() string { + return srv.name +} + +func (srv *server) TryWake() error { + waker, ok := srv.Handler.(idlewatcher.Waker) if ok { if err := waker.Wake(); err != nil { return err @@ -64,7 +84,3 @@ func (srv *Server) TryWake() error { } return nil } - -func (srv *Server) HealthMonitor() health.HealthMonitor { - return srv.healthMon -} diff --git a/internal/watcher/health/health_checker.go b/internal/watcher/health/health_checker.go index d886616..0bc0414 100644 --- a/internal/watcher/health/health_checker.go +++ b/internal/watcher/health/health_checker.go @@ -15,13 +15,17 @@ type ( Detail string Latency time.Duration } + WithHealthInfo interface { + Status() Status + Uptime() time.Duration + Latency() time.Duration + } HealthMonitor interface { task.TaskStarter task.TaskFinisher fmt.Stringer json.Marshaler - Status() Status - Uptime() time.Duration + WithHealthInfo Name() string } HealthChecker interface {