Performance improvements and a small fix for the load balancer

This commit is contained in:
yusing 2024-10-09 18:10:51 +08:00
parent 5c40f4aa84
commit d91b66ae87
11 changed files with 77 additions and 50 deletions

View file

@ -11,7 +11,6 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
@ -137,13 +136,11 @@ func main() {
signal.Notify(sig, syscall.SIGHUP) signal.Notify(sig, syscall.SIGHUP)
autocert := cfg.GetAutoCertProvider() autocert := cfg.GetAutoCertProvider()
if autocert != nil { if autocert != nil {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
onShutdown.Add(cancel)
if err := autocert.Setup(ctx); err != nil { if err := autocert.Setup(ctx); err != nil {
l.Fatal(err) l.Fatal(err)
} else {
onShutdown.Add(cancel)
} }
} else { } else {
l.Info("autocert not configured") l.Info("autocert not configured")
@ -179,19 +176,15 @@ func main() {
// grafully shutdown // grafully shutdown
logrus.Info("shutting down") logrus.Info("shutting down")
done := make(chan struct{}, 1) done := make(chan struct{}, 1)
currentIdx := 0
var wg sync.WaitGroup go func() {
wg.Add(onShutdown.Size()) onShutdown.ForEach(func(f func()) {
onShutdown.ForEach(func(f func()) {
go func() {
l.Debugf("waiting for %s to complete...", funcName(f)) l.Debugf("waiting for %s to complete...", funcName(f))
f() f()
currentIdx++
l.Debugf("%s done", funcName(f)) l.Debugf("%s done", funcName(f))
wg.Done() })
}()
})
go func() {
wg.Wait()
close(done) close(done)
}() }()
@ -201,9 +194,9 @@ func main() {
logrus.Info("shutdown complete") logrus.Info("shutdown complete")
case <-timeout: case <-timeout:
logrus.Info("timeout waiting for shutdown") logrus.Info("timeout waiting for shutdown")
onShutdown.ForEach(func(f func()) { for i := currentIdx; i < onShutdown.Size(); i++ {
l.Warnf("%s() is still running", funcName(f)) l.Warnf("%s() is still running", funcName(onShutdown.Get(i)))
}) }
} }
} }

View file

@ -220,7 +220,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action) errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
if err := do(p); err.HasError() { if err := do(p); err.HasError() {
errors.Add(err.Subject(p)) errors.Add(err.Subject(p))
} }

View file

@ -133,7 +133,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
} }
func CloseAllClients() { func CloseAllClients() {
clientMap.RangeAll(func(_ string, c Client) { clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close() c.Client.Close()
}) })
clientMap.Clear() clientMap.Clear()

View file

@ -46,7 +46,7 @@ func New(cfg Config) *LoadBalancer {
lb := &LoadBalancer{Config: cfg, pool: servers{}} lb := &LoadBalancer{Config: cfg, pool: servers{}}
mode := cfg.Mode mode := cfg.Mode
if !cfg.Mode.ValidateUpdate() { if !cfg.Mode.ValidateUpdate() {
logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode) logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode)
} }
switch mode { switch mode {
case RoundRobin: case RoundRobin:
@ -69,6 +69,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.sumWeight += srv.Weight lb.sumWeight += srv.Weight
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
} }
func (lb *LoadBalancer) RemoveServer(srv *Server) { func (lb *LoadBalancer) RemoveServer(srv *Server) {
@ -85,7 +86,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
} }
if lb.IsEmpty() { if lb.IsEmpty() {
lb.Stop() lb.Stop()
return
} }
lb.Rebalance()
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
} }
func (lb *LoadBalancer) IsEmpty() bool { func (lb *LoadBalancer) IsEmpty() bool {
@ -98,14 +103,14 @@ func (lb *LoadBalancer) Rebalance() {
} }
if lb.sumWeight == 0 { // distribute evenly if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(len(lb.pool)) weightEach := maxWeight / weightType(len(lb.pool))
remainer := maxWeight % weightType(len(lb.pool)) remainder := maxWeight % weightType(len(lb.pool))
for _, s := range lb.pool { for _, s := range lb.pool {
s.Weight = weightEach s.Weight = weightEach
lb.sumWeight += weightEach lb.sumWeight += weightEach
if remainer > 0 { if remainder > 0 {
s.Weight++ s.Weight++
remainder--
} }
remainer--
} }
return return
} }
@ -149,17 +154,13 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
} }
func (lb *LoadBalancer) Start() { func (lb *LoadBalancer) Start() {
if lb.IsEmpty() {
return
}
if lb.sumWeight != 0 && lb.sumWeight != maxWeight { if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight) msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
for _, s := range lb.pool { for _, s := range lb.pool {
msg.Addf("%s: %d", s.Name, s.Weight) msg.Addf("%s: %d", s.Name, s.Weight)
} }
lb.Rebalance() lb.Rebalance()
inner := E.NewBuilder("After rebalancing") inner := E.NewBuilder("after rebalancing")
for _, s := range lb.pool { for _, s := range lb.pool {
inner.Addf("%s: %d", s.Name, s.Weight) inner.Addf("%s: %d", s.Name, s.Weight)
} }
@ -168,22 +169,16 @@ func (lb *LoadBalancer) Start() {
} }
if lb.sumWeight != 0 { if lb.sumWeight != 0 {
log.Warnf("Weighted mode not supported yet") log.Warnf("weighted mode not supported yet")
}
switch lb.Mode {
case RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
lb.impl = lb.newIPHash()
} }
lb.done = make(chan struct{}, 1) lb.done = make(chan struct{}, 1)
lb.ctx, lb.cancel = context.WithCancel(context.Background()) lb.ctx, lb.cancel = context.WithCancel(context.Background())
updateAll := func() { updateAll := func() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(lb.pool)) wg.Add(len(lb.pool))
for _, s := range lb.pool { for _, s := range lb.pool {
@ -195,6 +190,8 @@ func (lb *LoadBalancer) Start() {
wg.Wait() wg.Wait()
} }
logger.Debugf("loadbalancer %s started", lb.Link)
go func() { go func() {
defer lb.cancel() defer lb.cancel()
defer close(lb.done) defer close(lb.done)
@ -208,16 +205,14 @@ func (lb *LoadBalancer) Start() {
case <-lb.ctx.Done(): case <-lb.ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
lb.poolMu.RLock()
updateAll() updateAll()
lb.poolMu.RUnlock()
} }
} }
}() }()
} }
func (lb *LoadBalancer) Stop() { func (lb *LoadBalancer) Stop() {
if lb.impl == nil { if lb.cancel == nil {
return return
} }
@ -225,6 +220,8 @@ func (lb *LoadBalancer) Stop() {
<-lb.done <-lb.done
lb.pool = nil lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
} }
func (lb *LoadBalancer) availServers() servers { func (lb *LoadBalancer) availServers() servers {

View file

@ -26,7 +26,7 @@ func NewServer(name string, url types.URL, weight weightType, handler http.Handl
srv := &Server{ srv := &Server{
Name: name, Name: name,
URL: url, URL: url,
Weight: weightType(weight), Weight: weight,
handler: handler, handler: handler,
pinger: &http.Client{Timeout: 3 * time.Second}, pinger: &http.Client{Timeout: 3 * time.Second},
} }

View file

@ -52,12 +52,12 @@ func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult)
return return
} }
routes.RangeAll(func(_ string, v R.Route) { routes.RangeAllParallel(func(_ string, v R.Route) {
b.Add(v.Stop()) b.Add(v.Stop())
}) })
routes.Clear() routes.Clear()
newRoutes.RangeAll(func(_ string, v R.Route) { newRoutes.RangeAllParallel(func(_ string, v R.Route) {
b.Add(v.Start()) b.Add(v.Start())
}) })

View file

@ -114,7 +114,7 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) {
nStarted := 0 nStarted := 0
nFailed := 0 nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) { p.routes.RangeAllParallel(func(alias string, r R.Route) {
if err := r.Start(); err.HasError() { if err := r.Start(); err.HasError() {
errors.Add(err.Subject(r)) errors.Add(err.Subject(r))
nFailed++ nFailed++
@ -138,7 +138,7 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) {
nStopped := 0 nStopped := 0
nFailed := 0 nFailed := 0
p.routes.RangeAll(func(alias string, r R.Route) { p.routes.RangeAllParallel(func(alias string, r R.Route) {
if err := r.Stop(); err.HasError() { if err := r.Stop(); err.HasError() {
errors.Add(err.Subject(r)) errors.Add(err.Subject(r))
nFailed++ nFailed++

View file

@ -22,7 +22,7 @@ import (
type ( type (
HTTPRoute struct { HTTPRoute struct {
*P.ReverseProxyEntry *P.ReverseProxyEntry
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"` LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer"`
server *loadbalancer.Server server *loadbalancer.Server
handler http.Handler handler http.Handler
@ -134,7 +134,8 @@ func (r *HTTPRoute) Start() E.NestedError {
} }
httpRoutes.Store(string(r.LoadBalance.Link), linked) httpRoutes.Store(string(r.LoadBalance.Link), linked)
} }
lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)) r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)
lb.AddServer(r.server)
return nil return nil
} }
@ -150,8 +151,6 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
waker.Unregister() waker.Unregister()
} }
r.handler = nil
if r.server != nil { if r.server != nil {
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
if ok { if ok {
@ -160,9 +159,13 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
if linked.LoadBalancer.IsEmpty() { if linked.LoadBalancer.IsEmpty() {
httpRoutes.Delete(string(r.LoadBalance.Link)) httpRoutes.Delete(string(r.LoadBalance.Link))
} }
r.server = nil
} else { } else {
httpRoutes.Delete(string(r.Alias)) httpRoutes.Delete(string(r.Alias))
} }
r.handler = nil
return return
} }

View file

@ -108,7 +108,7 @@ func (route *UDPRoute) CloseListeners() {
route.listeningConn.Close() route.listeningConn.Close()
route.listeningConn = nil route.listeningConn = nil
} }
route.connMap.RangeAll(func(_ string, conn *UDPConn) { route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil { if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err) route.l.Errorf("error closing src conn: %s", err)
} }

View file

@ -1,6 +1,8 @@
package functional package functional
import ( import (
"sync"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@ -75,6 +77,20 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) {
}) })
} }
// RangeAllParallel visits every entry of the map, invoking do(k, v) for each
// entry in its own goroutine, and blocks until all invocations have returned.
//
// Fix over the original: the counter was primed with wg.Add(m.Size()) before
// iterating. If the map's size changed between Size() and Range (the map is a
// concurrent xsync map, so that is possible), the WaitGroup count diverged
// from the number of goroutines actually started — Wait() would either block
// forever or panic on a negative counter. Adding 1 per visited entry inside
// the Range callback keeps the count exact by construction.
//
// Note: this spawns one goroutine per entry with no bound; callers should
// only use it for maps of modest size (its existing call sites are
// shutdown/teardown paths).
func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) {
	var wg sync.WaitGroup
	m.Range(func(k KT, v VT) bool {
		wg.Add(1)
		go func() {
			defer wg.Done()
			do(k, v)
		}()
		return true
	})
	wg.Wait()
}
func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) { func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) {
m.Range(func(k KT, v VT) bool { m.Range(func(k KT, v VT) bool {
if criteria(v) { if criteria(v) {

View file

@ -38,6 +38,10 @@ func (s *Slice[T]) Iterator() []T {
return s.s return s.s
} }
// Get returns the element at index i.
//
// Not synchronized: like Set/Iterator it touches s.s directly without taking
// s.mu, so the caller is responsible for synchronization under concurrent
// use. Panics if i is out of range (plain slice indexing).
func (s *Slice[T]) Get(i int) T {
	return s.s[i]
}
func (s *Slice[T]) Set(i int, v T) { func (s *Slice[T]) Set(i int, v T) {
s.s[i] = v s.s[i] = v
} }
@ -76,6 +80,20 @@ func (s *Slice[T]) SafePop() T {
return s.Pop() return s.Pop()
} }
// Remove deletes every element for which criteria returns true, preserving
// the relative order of the kept elements. Runs in O(n) and reuses the
// existing backing array (no allocation).
//
// Fix over the original: it called append(s.s[:i], s.s[i+1:]...) while
// ranging over s.s by index. Each deletion shifts the remaining elements
// left, but the range loop keeps iterating with its original indices, so the
// element immediately after a match was skipped (two adjacent matches left
// the second one in place) and stale tail values could be revisited.
// Rewritten as an in-place filter, which handles adjacent matches correctly.
//
// Not synchronized; use SafeRemove when concurrent access is possible.
func (s *Slice[T]) Remove(criteria func(T) bool) {
	kept := s.s[:0]
	for _, v := range s.s {
		if !criteria(v) {
			kept = append(kept, v)
		}
	}
	s.s = kept
}
// SafeRemove is the mutex-guarded variant of Remove: it holds s.mu for the
// whole removal (deferred unlock, so the lock is released even if criteria
// panics). Use this instead of Remove when the slice may be accessed
// concurrently.
func (s *Slice[T]) SafeRemove(criteria func(T) bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.Remove(criteria)
}
func (s *Slice[T]) ForEach(do func(T)) { func (s *Slice[T]) ForEach(do func(T)) {
for _, v := range s.s { for _, v := range s.s {
do(v) do(v)