mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
287 lines
6.2 KiB
Go
287 lines
6.2 KiB
Go
package loadbalancer
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/yusing/go-proxy/internal/common"
|
|
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
|
E "github.com/yusing/go-proxy/internal/error"
|
|
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
|
"github.com/yusing/go-proxy/internal/task"
|
|
"github.com/yusing/go-proxy/internal/watcher/health"
|
|
)
|
|
|
|
// TODO: stats of each server.
|
|
// TODO: support weighted mode.
|
|
type (
|
|
impl interface {
|
|
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
|
OnAddServer(srv *Server)
|
|
OnRemoveServer(srv *Server)
|
|
}
|
|
Config struct {
|
|
Link string `json:"link" yaml:"link"`
|
|
Mode Mode `json:"mode" yaml:"mode"`
|
|
Weight weightType `json:"weight" yaml:"weight"`
|
|
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
|
|
}
|
|
LoadBalancer struct {
|
|
impl
|
|
*Config
|
|
|
|
task task.Task
|
|
|
|
pool Pool
|
|
poolMu sync.Mutex
|
|
|
|
sumWeight weightType
|
|
startTime time.Time
|
|
}
|
|
|
|
weightType uint16
|
|
)
|
|
|
|
// maxWeight is the total that rebalance normalizes server weights to.
const maxWeight = weightType(100)
|
|
|
|
func New(cfg *Config) *LoadBalancer {
|
|
lb := &LoadBalancer{
|
|
Config: new(Config),
|
|
pool: newPool(),
|
|
task: task.DummyTask(),
|
|
}
|
|
lb.UpdateConfigIfNeeded(cfg)
|
|
return lb
|
|
}
|
|
|
|
// Start implements task.TaskStarter.
|
|
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
|
|
lb.startTime = time.Now()
|
|
lb.task = routeSubtask
|
|
lb.task.OnFinished("loadbalancer cleanup", func() {
|
|
if lb.impl != nil {
|
|
lb.pool.RangeAll(func(k string, v *Server) {
|
|
lb.impl.OnRemoveServer(v)
|
|
})
|
|
}
|
|
lb.pool.Clear()
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Finish implements task.TaskFinisher.
|
|
func (lb *LoadBalancer) Finish(reason any) {
|
|
lb.task.Finish(reason)
|
|
}
|
|
|
|
func (lb *LoadBalancer) updateImpl() {
|
|
switch lb.Mode {
|
|
case Unset, RoundRobin:
|
|
lb.impl = lb.newRoundRobin()
|
|
case LeastConn:
|
|
lb.impl = lb.newLeastConn()
|
|
case IPHash:
|
|
lb.impl = lb.newIPHash()
|
|
default: // should happen in test only
|
|
lb.impl = lb.newRoundRobin()
|
|
}
|
|
lb.pool.RangeAll(func(_ string, srv *Server) {
|
|
lb.impl.OnAddServer(srv)
|
|
})
|
|
}
|
|
|
|
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
|
if cfg != nil {
|
|
lb.poolMu.Lock()
|
|
defer lb.poolMu.Unlock()
|
|
|
|
lb.Link = cfg.Link
|
|
|
|
if lb.Mode == Unset && cfg.Mode != Unset {
|
|
lb.Mode = cfg.Mode
|
|
if !lb.Mode.ValidateUpdate() {
|
|
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode)
|
|
}
|
|
lb.updateImpl()
|
|
}
|
|
|
|
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
|
|
lb.Options = cfg.Options
|
|
}
|
|
}
|
|
|
|
if lb.impl == nil {
|
|
lb.updateImpl()
|
|
}
|
|
}
|
|
|
|
func (lb *LoadBalancer) AddServer(srv *Server) {
|
|
lb.poolMu.Lock()
|
|
defer lb.poolMu.Unlock()
|
|
|
|
if lb.pool.Has(srv.Name) {
|
|
old, _ := lb.pool.Load(srv.Name)
|
|
lb.sumWeight -= old.Weight
|
|
lb.impl.OnRemoveServer(old)
|
|
}
|
|
lb.pool.Store(srv.Name, srv)
|
|
lb.sumWeight += srv.Weight
|
|
|
|
lb.rebalance()
|
|
lb.impl.OnAddServer(srv)
|
|
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
|
|
}
|
|
|
|
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
|
lb.poolMu.Lock()
|
|
defer lb.poolMu.Unlock()
|
|
|
|
if !lb.pool.Has(srv.Name) {
|
|
return
|
|
}
|
|
|
|
lb.pool.Delete(srv.Name)
|
|
|
|
lb.sumWeight -= srv.Weight
|
|
lb.rebalance()
|
|
lb.impl.OnRemoveServer(srv)
|
|
|
|
if lb.pool.Size() == 0 {
|
|
lb.task.Finish("no server left")
|
|
logger.Infof("loadbalancer %s stopped", lb.Link)
|
|
return
|
|
}
|
|
|
|
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
|
|
}
|
|
|
|
func (lb *LoadBalancer) rebalance() {
|
|
if lb.sumWeight == maxWeight {
|
|
return
|
|
}
|
|
if lb.pool.Size() == 0 {
|
|
return
|
|
}
|
|
if lb.sumWeight == 0 { // distribute evenly
|
|
weightEach := maxWeight / weightType(lb.pool.Size())
|
|
remainder := maxWeight % weightType(lb.pool.Size())
|
|
lb.pool.RangeAll(func(_ string, s *Server) {
|
|
s.Weight = weightEach
|
|
lb.sumWeight += weightEach
|
|
if remainder > 0 {
|
|
s.Weight++
|
|
remainder--
|
|
}
|
|
})
|
|
return
|
|
}
|
|
|
|
// scale evenly
|
|
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
|
lb.sumWeight = 0
|
|
|
|
lb.pool.RangeAll(func(_ string, s *Server) {
|
|
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
|
lb.sumWeight += s.Weight
|
|
})
|
|
|
|
delta := maxWeight - lb.sumWeight
|
|
if delta == 0 {
|
|
return
|
|
}
|
|
lb.pool.Range(func(_ string, s *Server) bool {
|
|
if delta == 0 {
|
|
return false
|
|
}
|
|
if delta > 0 {
|
|
s.Weight++
|
|
lb.sumWeight++
|
|
delta--
|
|
} else {
|
|
s.Weight--
|
|
lb.sumWeight--
|
|
delta++
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|
srvs := lb.availServers()
|
|
if len(srvs) == 0 {
|
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
|
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
|
|
defer cancel()
|
|
// send dummy request to wake all servers
|
|
var dummyRW *DummyResponseWriter
|
|
for _, srv := range srvs {
|
|
// wake only if server implements Waker
|
|
_, ok := srv.handler.(idlewatcher.Waker)
|
|
if !ok {
|
|
continue
|
|
}
|
|
wakeReq := r.Clone(ctx)
|
|
srv.ServeHTTP(dummyRW, wakeReq)
|
|
}
|
|
}
|
|
lb.impl.ServeHTTP(srvs, rw, r)
|
|
}
|
|
|
|
func (lb *LoadBalancer) Uptime() time.Duration {
|
|
return time.Since(lb.startTime)
|
|
}
|
|
|
|
// MarshalJSON implements health.HealthMonitor.
|
|
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
|
extra := make(map[string]any)
|
|
lb.pool.RangeAll(func(k string, v *Server) {
|
|
extra[v.Name] = v.healthMon
|
|
})
|
|
|
|
return (&health.JSONRepresentation{
|
|
Name: lb.Name(),
|
|
Status: lb.Status(),
|
|
Started: lb.startTime,
|
|
Uptime: lb.Uptime(),
|
|
Extra: map[string]any{
|
|
"config": lb.Config,
|
|
"pool": extra,
|
|
},
|
|
}).MarshalJSON()
|
|
}
|
|
|
|
// Name implements health.HealthMonitor.
|
|
func (lb *LoadBalancer) Name() string {
|
|
return lb.Link
|
|
}
|
|
|
|
// Status implements health.HealthMonitor.
|
|
func (lb *LoadBalancer) Status() health.Status {
|
|
if lb.pool.Size() == 0 {
|
|
return health.StatusUnknown
|
|
}
|
|
if len(lb.availServers()) == 0 {
|
|
return health.StatusUnhealthy
|
|
}
|
|
return health.StatusHealthy
|
|
}
|
|
|
|
// String implements health.HealthMonitor.
|
|
func (lb *LoadBalancer) String() string {
|
|
return lb.Name()
|
|
}
|
|
|
|
func (lb *LoadBalancer) availServers() []*Server {
|
|
avail := make([]*Server, 0, lb.pool.Size())
|
|
lb.pool.RangeAll(func(_ string, srv *Server) {
|
|
if srv.Status().Good() {
|
|
avail = append(avail, srv)
|
|
}
|
|
})
|
|
return avail
|
|
}
|