cleanup some loadbalancer code

This commit is contained in:
yusing 2025-01-19 04:32:50 +08:00
parent 589b3a7a13
commit b253dce7e1
8 changed files with 136 additions and 93 deletions

View file

@ -31,7 +31,7 @@ func (lb *LoadBalancer) newIPHash() impl {
return impl return impl
} }
func (impl *ipHash) OnAddServer(srv *Server) { func (impl *ipHash) OnAddServer(srv Server) {
impl.mu.Lock() impl.mu.Lock()
defer impl.mu.Unlock() defer impl.mu.Unlock()
@ -48,7 +48,7 @@ func (impl *ipHash) OnAddServer(srv *Server) {
impl.pool = append(impl.pool, srv) impl.pool = append(impl.pool, srv)
} }
func (impl *ipHash) OnRemoveServer(srv *Server) { func (impl *ipHash) OnRemoveServer(srv Server) {
impl.mu.Lock() impl.mu.Lock()
defer impl.mu.Unlock() defer impl.mu.Unlock()

View file

@ -9,21 +9,21 @@ import (
type leastConn struct { type leastConn struct {
*LoadBalancer *LoadBalancer
nConn F.Map[*Server, *atomic.Int64] nConn F.Map[Server, *atomic.Int64]
} }
func (lb *LoadBalancer) newLeastConn() impl { func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{ return &leastConn{
LoadBalancer: lb, LoadBalancer: lb,
nConn: F.NewMapOf[*Server, *atomic.Int64](), nConn: F.NewMapOf[Server, *atomic.Int64](),
} }
} }
func (impl *leastConn) OnAddServer(srv *Server) { func (impl *leastConn) OnAddServer(srv Server) {
impl.nConn.Store(srv, new(atomic.Int64)) impl.nConn.Store(srv, new(atomic.Int64))
} }
func (impl *leastConn) OnRemoveServer(srv *Server) { func (impl *leastConn) OnRemoveServer(srv Server) {
impl.nConn.Delete(srv) impl.nConn.Delete(srv)
} }
@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.R
srv := srvs[0] srv := srvs[0]
minConn, ok := impl.nConn.Load(srv) minConn, ok := impl.nConn.Load(srv)
if !ok { if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name) impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
} }
for i := 1; i < len(srvs); i++ { for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i]) nConn, ok := impl.nConn.Load(srvs[i])
if !ok { if !ok {
impl.l.Error().Msgf("[BUG] server %s not found", srv.Name) impl.l.Error().Msgf("[BUG] server %s not found", srv.Name())
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
} }
if nConn.Load() < minConn.Load() { if nConn.Load() < minConn.Load() {

View file

@ -20,8 +20,8 @@ import (
type ( type (
impl interface { impl interface {
ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv *Server) OnAddServer(srv Server)
OnRemoveServer(srv *Server) OnRemoveServer(srv Server)
} }
LoadBalancer struct { LoadBalancer struct {
@ -61,7 +61,7 @@ func (lb *LoadBalancer) Start(parent task.Parent) E.Error {
}) })
lb.task.OnFinished("cleanup", func() { lb.task.OnFinished("cleanup", func() {
if lb.impl != nil { if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) { lb.pool.RangeAll(func(k string, v Server) {
lb.impl.OnRemoveServer(v) lb.impl.OnRemoveServer(v)
}) })
} }
@ -90,7 +90,7 @@ func (lb *LoadBalancer) updateImpl() {
default: // should happen in test only default: // should happen in test only
lb.impl = lb.newRoundRobin() lb.impl = lb.newRoundRobin()
} }
lb.pool.RangeAll(func(_ string, srv *Server) { lb.pool.RangeAll(func(_ string, srv Server) {
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
}) })
} }
@ -120,44 +120,44 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
} }
} }
func (lb *LoadBalancer) AddServer(srv *Server) { func (lb *LoadBalancer) AddServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if lb.pool.Has(srv.Name) { if lb.pool.Has(srv.Name()) {
old, _ := lb.pool.Load(srv.Name) old, _ := lb.pool.Load(srv.Name())
lb.sumWeight -= old.Weight lb.sumWeight -= old.Weight()
lb.impl.OnRemoveServer(old) lb.impl.OnRemoveServer(old)
} }
lb.pool.Store(srv.Name, srv) lb.pool.Store(srv.Name(), srv)
lb.sumWeight += srv.Weight lb.sumWeight += srv.Weight()
lb.rebalance() lb.rebalance()
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
lb.l.Debug(). lb.l.Debug().
Str("action", "add"). Str("action", "add").
Str("server", srv.Name). Str("server", srv.Name()).
Msgf("%d servers available", lb.pool.Size()) Msgf("%d servers available", lb.pool.Size())
} }
func (lb *LoadBalancer) RemoveServer(srv *Server) { func (lb *LoadBalancer) RemoveServer(srv Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
if !lb.pool.Has(srv.Name) { if !lb.pool.Has(srv.Name()) {
return return
} }
lb.pool.Delete(srv.Name) lb.pool.Delete(srv.Name())
lb.sumWeight -= srv.Weight lb.sumWeight -= srv.Weight()
lb.rebalance() lb.rebalance()
lb.impl.OnRemoveServer(srv) lb.impl.OnRemoveServer(srv)
lb.l.Debug(). lb.l.Debug().
Str("action", "remove"). Str("action", "remove").
Str("server", srv.Name). Str("server", srv.Name()).
Msgf("%d servers left", lb.pool.Size()) Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 { if lb.pool.Size() == 0 {
@ -178,13 +178,14 @@ func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == 0 { // distribute evenly if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / Weight(poolSize) weightEach := maxWeight / Weight(poolSize)
remainder := maxWeight % Weight(poolSize) remainder := maxWeight % Weight(poolSize)
lb.pool.RangeAll(func(_ string, s *Server) { lb.pool.RangeAll(func(_ string, s Server) {
s.Weight = weightEach w := weightEach
lb.sumWeight += weightEach lb.sumWeight += weightEach
if remainder > 0 { if remainder > 0 {
s.Weight++ w++
remainder-- remainder--
} }
s.SetWeight(w)
}) })
return return
} }
@ -193,25 +194,25 @@ func (lb *LoadBalancer) rebalance() {
scaleFactor := float64(maxWeight) / float64(lb.sumWeight) scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0 lb.sumWeight = 0
lb.pool.RangeAll(func(_ string, s *Server) { lb.pool.RangeAll(func(_ string, s Server) {
s.Weight = Weight(float64(s.Weight) * scaleFactor) s.SetWeight(Weight(float64(s.Weight()) * scaleFactor))
lb.sumWeight += s.Weight lb.sumWeight += s.Weight()
}) })
delta := maxWeight - lb.sumWeight delta := maxWeight - lb.sumWeight
if delta == 0 { if delta == 0 {
return return
} }
lb.pool.Range(func(_ string, s *Server) bool { lb.pool.Range(func(_ string, s Server) bool {
if delta == 0 { if delta == 0 {
return false return false
} }
if delta > 0 { if delta > 0 {
s.Weight++ s.SetWeight(s.Weight() + 1)
lb.sumWeight++ lb.sumWeight++
delta-- delta--
} else { } else {
s.Weight-- s.SetWeight(s.Weight() - 1)
lb.sumWeight-- lb.sumWeight--
delta++ delta++
} }
@ -229,22 +230,20 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// wake all servers // wake all servers
for _, srv := range srvs { for _, srv := range srvs {
if err := srv.TryWake(); err != nil { if err := srv.TryWake(); err != nil {
lb.l.Warn().Err(err).Str("server", srv.Name).Msg("failed to wake server") lb.l.Warn().Err(err).
Str("server", srv.Name()).
Msg("failed to wake server")
} }
} }
} }
lb.impl.ServeHTTP(srvs, rw, r) lb.impl.ServeHTTP(srvs, rw, r)
} }
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// MarshalJSON implements health.HealthMonitor. // MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any) extra := make(map[string]any)
lb.pool.RangeAll(func(k string, v *Server) { lb.pool.RangeAll(func(k string, v Server) {
extra[v.Name] = v.HealthMonitor() extra[v.Name()] = v
}) })
return (&monitor.JSONRepresentation{ return (&monitor.JSONRepresentation{
@ -269,20 +268,43 @@ func (lb *LoadBalancer) Status() health.Status {
if lb.pool.Size() == 0 { if lb.pool.Size() == 0 {
return health.StatusUnknown return health.StatusUnknown
} }
if len(lb.availServers()) == 0 {
isHealthy := true
lb.pool.Range(func(_ string, srv Server) bool {
if srv.Status().Bad() {
isHealthy = false
return false
}
return true
})
if !isHealthy {
return health.StatusUnhealthy return health.StatusUnhealthy
} }
return health.StatusHealthy return health.StatusHealthy
} }
// Uptime implements health.HealthMonitor.
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// Latency implements health.HealthMonitor.
func (lb *LoadBalancer) Latency() time.Duration {
var sum time.Duration
lb.pool.RangeAll(func(_ string, srv Server) {
sum += srv.Latency()
})
return sum
}
// String implements health.HealthMonitor. // String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string { func (lb *LoadBalancer) String() string {
return lb.Name() return lb.Name()
} }
func (lb *LoadBalancer) availServers() []*Server { func (lb *LoadBalancer) availServers() []Server {
avail := make([]*Server, 0, lb.pool.Size()) avail := make([]Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv *Server) { lb.pool.RangeAll(func(_ string, srv Server) {
if srv.Status().Good() { if srv.Status().Good() {
avail = append(avail, srv) avail = append(avail, srv)
} }

View file

@ -3,6 +3,7 @@ package loadbalancer
import ( import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types" loadbalance "github.com/yusing/go-proxy/internal/net/http/loadbalancer/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -12,31 +13,31 @@ func TestRebalance(t *testing.T) {
t.Run("zero", func(t *testing.T) { t.Run("zero", func(t *testing.T) {
lb := New(new(loadbalance.Config)) lb := New(new(loadbalance.Config))
for range 10 { for range 10 {
lb.AddServer(&Server{}) lb.AddServer(types.TestNewServer(0))
} }
lb.rebalance() lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)
}) })
t.Run("less", func(t *testing.T) { t.Run("less", func(t *testing.T) {
lb := New(new(loadbalance.Config)) lb := New(new(loadbalance.Config))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance() lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)
}) })
t.Run("more", func(t *testing.T) { t.Run("more", func(t *testing.T) {
lb := New(new(loadbalance.Config)) lb := New(new(loadbalance.Config))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .4)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .4))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .3)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .3))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .2)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .2))
lb.AddServer(&Server{Weight: loadbalance.Weight(float64(maxWeight) * .1)}) lb.AddServer(types.TestNewServer(float64(maxWeight) * .1))
lb.rebalance() lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)

View file

@ -10,8 +10,8 @@ type roundRobin struct {
} }
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} } func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv *Server) {} func (lb *roundRobin) OnAddServer(srv Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {} func (lb *roundRobin) OnRemoveServer(srv Server) {}
func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) { func (lb *roundRobin) ServeHTTP(srvs Servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1) % uint32(len(srvs)) index := lb.index.Add(1) % uint32(len(srvs))

View file

@ -6,7 +6,7 @@ import (
type ( type (
Server = types.Server Server = types.Server
Servers = types.Servers Servers = []types.Server
Pool = types.Pool Pool = types.Pool
Weight = types.Weight Weight = types.Weight
Config = types.Config Config = types.Config

View file

@ -2,7 +2,6 @@ package types
import ( import (
"net/http" "net/http"
"time"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
@ -12,51 +11,72 @@ import (
) )
type ( type (
Server struct { server struct {
_ U.NoCopy _ U.NoCopy
Name string name string
URL types.URL url types.URL
Weight Weight weight Weight
handler http.Handler http.Handler `json:"-"`
healthMon health.HealthMonitor health.HealthMonitor
} }
Servers = []*Server
Pool = F.Map[string, *Server] Server interface {
http.Handler
health.HealthMonitor
Name() string
URL() types.URL
Weight() Weight
SetWeight(Weight)
TryWake() error
}
Pool = F.Map[string, Server]
) )
var NewServerPool = F.NewMap[Pool] var NewServerPool = F.NewMap[Pool]
func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) *Server { func NewServer(name string, url types.URL, weight Weight, handler http.Handler, healthMon health.HealthMonitor) Server {
srv := &Server{ srv := &server{
Name: name, name: name,
URL: url, url: url,
Weight: weight, weight: weight,
handler: handler, Handler: handler,
healthMon: healthMon, HealthMonitor: healthMon,
} }
return srv return srv
} }
func (srv *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func TestNewServer[T ~int | ~float32 | ~float64](weight T) Server {
srv.handler.ServeHTTP(rw, r) srv := &server{
weight: Weight(weight),
}
return srv
} }
func (srv *Server) String() string { func (srv *server) Name() string {
return srv.Name return srv.name
} }
func (srv *Server) Status() health.Status { func (srv *server) URL() types.URL {
return srv.healthMon.Status() return srv.url
} }
func (srv *Server) Uptime() time.Duration { func (srv *server) Weight() Weight {
return srv.healthMon.Uptime() return srv.weight
} }
func (srv *Server) TryWake() error { func (srv *server) SetWeight(weight Weight) {
waker, ok := srv.handler.(idlewatcher.Waker) srv.weight = weight
}
func (srv *server) String() string {
return srv.name
}
func (srv *server) TryWake() error {
waker, ok := srv.Handler.(idlewatcher.Waker)
if ok { if ok {
if err := waker.Wake(); err != nil { if err := waker.Wake(); err != nil {
return err return err
@ -64,7 +84,3 @@ func (srv *Server) TryWake() error {
} }
return nil return nil
} }
func (srv *Server) HealthMonitor() health.HealthMonitor {
return srv.healthMon
}

View file

@ -15,13 +15,17 @@ type (
Detail string Detail string
Latency time.Duration Latency time.Duration
} }
WithHealthInfo interface {
Status() Status
Uptime() time.Duration
Latency() time.Duration
}
HealthMonitor interface { HealthMonitor interface {
task.TaskStarter task.TaskStarter
task.TaskFinisher task.TaskFinisher
fmt.Stringer fmt.Stringer
json.Marshaler json.Marshaler
Status() Status WithHealthInfo
Uptime() time.Duration
Name() string Name() string
} }
HealthChecker interface { HealthChecker interface {