mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-01 13:04:25 +02:00
added round_robin, least_conn and ip_hash load balance support, small refactoring
This commit is contained in:
parent
1797896fa6
commit
5c40f4aa84
24 changed files with 739 additions and 64 deletions
|
@ -212,7 +212,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
cfg.proxyProviders.Store(p.GetName(), p)
|
cfg.proxyProviders.Store(p.GetName(), p)
|
||||||
b.Add(p.LoadRoutes().Subject(dockerHost))
|
b.Add(p.LoadRoutes().Subject(p.GetName()))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package idlewatcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -19,10 +18,6 @@ type Waker struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
||||||
tr := &http.Transport{}
|
|
||||||
if w.NoTLSVerify {
|
|
||||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
||||||
}
|
|
||||||
orig := rp.ServeHTTP
|
orig := rp.ServeHTTP
|
||||||
// workaround for stopped containers port become zero
|
// workaround for stopped containers port become zero
|
||||||
rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) {
|
rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -41,7 +36,7 @@ func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
||||||
watcher: w,
|
watcher: w,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Timeout: 1 * time.Second,
|
Timeout: 1 * time.Second,
|
||||||
Transport: tr,
|
Transport: rp.Transport,
|
||||||
},
|
},
|
||||||
rp: rp,
|
rp: rp,
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,6 +65,10 @@ func (b Builder) To(ptr *NestedError) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b Builder) String() string {
|
||||||
|
return b.Build().String()
|
||||||
|
}
|
||||||
|
|
||||||
func (b Builder) HasError() bool {
|
func (b Builder) HasError() bool {
|
||||||
return len(b.errors) > 0
|
return len(b.errors) > 0
|
||||||
}
|
}
|
||||||
|
|
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"hash/fnv"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ipHash struct{ *LoadBalancer }
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
|
||||||
|
func (ipHash) OnAddServer(srv *Server) {}
|
||||||
|
func (ipHash) OnRemoveServer(srv *Server) {}
|
||||||
|
|
||||||
|
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||||
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||||
|
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||||
|
if !impl.pool[idx].available.Load() {
|
||||||
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
impl.pool[idx].handler.ServeHTTP(rw, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashIP(ip string) uint32 {
|
||||||
|
h := fnv.New32a()
|
||||||
|
h.Write([]byte(ip))
|
||||||
|
return h.Sum32()
|
||||||
|
}
|
53
internal/net/http/loadbalancer/least_conn.go
Normal file
53
internal/net/http/loadbalancer/least_conn.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
|
)
|
||||||
|
|
||||||
|
type leastConn struct {
|
||||||
|
*LoadBalancer
|
||||||
|
nConn F.Map[*Server, *atomic.Int64]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) newLeastConn() impl {
|
||||||
|
return &leastConn{
|
||||||
|
LoadBalancer: lb,
|
||||||
|
nConn: F.NewMapOf[*Server, *atomic.Int64](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (impl *leastConn) OnAddServer(srv *Server) {
|
||||||
|
impl.nConn.Store(srv, new(atomic.Int64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (impl *leastConn) OnRemoveServer(srv *Server) {
|
||||||
|
impl.nConn.Delete(srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||||
|
srv := srvs[0]
|
||||||
|
minConn, ok := impl.nConn.Load(srv)
|
||||||
|
if !ok {
|
||||||
|
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||||
|
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(srvs); i++ {
|
||||||
|
nConn, ok := impl.nConn.Load(srvs[i])
|
||||||
|
if !ok {
|
||||||
|
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||||
|
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if nConn.Load() < minConn.Load() {
|
||||||
|
minConn = nConn
|
||||||
|
srv = srvs[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
minConn.Add(1)
|
||||||
|
srv.handler.ServeHTTP(rw, r)
|
||||||
|
minConn.Add(-1)
|
||||||
|
}
|
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
|
@ -0,0 +1,241 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-acme/lego/v4/log"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: stats of each server
|
||||||
|
// TODO: support weighted mode
|
||||||
|
type (
|
||||||
|
impl interface {
|
||||||
|
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||||
|
OnAddServer(srv *Server)
|
||||||
|
OnRemoveServer(srv *Server)
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Link string
|
||||||
|
Mode Mode
|
||||||
|
Weight weightType
|
||||||
|
}
|
||||||
|
LoadBalancer struct {
|
||||||
|
impl
|
||||||
|
Config
|
||||||
|
|
||||||
|
pool servers
|
||||||
|
poolMu sync.RWMutex
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
sumWeight weightType
|
||||||
|
}
|
||||||
|
|
||||||
|
weightType uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxWeight weightType = 100
|
||||||
|
|
||||||
|
func New(cfg Config) *LoadBalancer {
|
||||||
|
lb := &LoadBalancer{Config: cfg, pool: servers{}}
|
||||||
|
mode := cfg.Mode
|
||||||
|
if !cfg.Mode.ValidateUpdate() {
|
||||||
|
logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode)
|
||||||
|
}
|
||||||
|
switch mode {
|
||||||
|
case RoundRobin:
|
||||||
|
lb.impl = lb.newRoundRobin()
|
||||||
|
case LeastConn:
|
||||||
|
lb.impl = lb.newLeastConn()
|
||||||
|
case IPHash:
|
||||||
|
lb.impl = lb.newIPHash()
|
||||||
|
default: // should happen in test only
|
||||||
|
lb.impl = lb.newRoundRobin()
|
||||||
|
}
|
||||||
|
return lb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) AddServer(srv *Server) {
|
||||||
|
lb.poolMu.Lock()
|
||||||
|
defer lb.poolMu.Unlock()
|
||||||
|
|
||||||
|
lb.pool = append(lb.pool, srv)
|
||||||
|
lb.sumWeight += srv.Weight
|
||||||
|
|
||||||
|
lb.impl.OnAddServer(srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||||
|
lb.poolMu.RLock()
|
||||||
|
defer lb.poolMu.RUnlock()
|
||||||
|
|
||||||
|
lb.impl.OnRemoveServer(srv)
|
||||||
|
|
||||||
|
for i, s := range lb.pool {
|
||||||
|
if s == srv {
|
||||||
|
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lb.IsEmpty() {
|
||||||
|
lb.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) IsEmpty() bool {
|
||||||
|
return len(lb.pool) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) Rebalance() {
|
||||||
|
if lb.sumWeight == maxWeight {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lb.sumWeight == 0 { // distribute evenly
|
||||||
|
weightEach := maxWeight / weightType(len(lb.pool))
|
||||||
|
remainer := maxWeight % weightType(len(lb.pool))
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
s.Weight = weightEach
|
||||||
|
lb.sumWeight += weightEach
|
||||||
|
if remainer > 0 {
|
||||||
|
s.Weight++
|
||||||
|
}
|
||||||
|
remainer--
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// scale evenly
|
||||||
|
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||||
|
lb.sumWeight = 0
|
||||||
|
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||||
|
lb.sumWeight += s.Weight
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := maxWeight - lb.sumWeight
|
||||||
|
if delta == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
if delta == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if delta > 0 {
|
||||||
|
s.Weight++
|
||||||
|
lb.sumWeight++
|
||||||
|
delta--
|
||||||
|
} else {
|
||||||
|
s.Weight--
|
||||||
|
lb.sumWeight--
|
||||||
|
delta++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
srvs := lb.availServers()
|
||||||
|
if len(srvs) == 0 {
|
||||||
|
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lb.impl.ServeHTTP(srvs, rw, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) Start() {
|
||||||
|
if lb.IsEmpty() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
|
||||||
|
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
msg.Addf("%s: %d", s.Name, s.Weight)
|
||||||
|
}
|
||||||
|
lb.Rebalance()
|
||||||
|
inner := E.NewBuilder("After rebalancing")
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
inner.Addf("%s: %d", s.Name, s.Weight)
|
||||||
|
}
|
||||||
|
msg.Addf("%s", inner)
|
||||||
|
logger.Warn(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lb.sumWeight != 0 {
|
||||||
|
log.Warnf("Weighted mode not supported yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch lb.Mode {
|
||||||
|
case RoundRobin:
|
||||||
|
lb.impl = lb.newRoundRobin()
|
||||||
|
case LeastConn:
|
||||||
|
lb.impl = lb.newLeastConn()
|
||||||
|
case IPHash:
|
||||||
|
lb.impl = lb.newIPHash()
|
||||||
|
}
|
||||||
|
|
||||||
|
lb.done = make(chan struct{}, 1)
|
||||||
|
lb.ctx, lb.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
updateAll := func() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(lb.pool))
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
go func(s *Server) {
|
||||||
|
defer wg.Done()
|
||||||
|
s.checkUpdateAvail(lb.ctx)
|
||||||
|
}(s)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer lb.cancel()
|
||||||
|
defer close(lb.done)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
updateAll()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-lb.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
lb.poolMu.RLock()
|
||||||
|
updateAll()
|
||||||
|
lb.poolMu.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) Stop() {
|
||||||
|
if lb.impl == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lb.cancel()
|
||||||
|
|
||||||
|
<-lb.done
|
||||||
|
lb.pool = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lb *LoadBalancer) availServers() servers {
|
||||||
|
lb.poolMu.Lock()
|
||||||
|
defer lb.poolMu.Unlock()
|
||||||
|
|
||||||
|
avail := servers{}
|
||||||
|
for _, s := range lb.pool {
|
||||||
|
if s.available.Load() {
|
||||||
|
avail = append(avail, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return avail
|
||||||
|
}
|
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRebalance(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
t.Run("zero", func(t *testing.T) {
|
||||||
|
lb := New(Config{})
|
||||||
|
for range 10 {
|
||||||
|
lb.AddServer(&Server{})
|
||||||
|
}
|
||||||
|
lb.Rebalance()
|
||||||
|
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||||
|
})
|
||||||
|
t.Run("less", func(t *testing.T) {
|
||||||
|
lb := New(Config{})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||||
|
lb.Rebalance()
|
||||||
|
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||||
|
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||||
|
})
|
||||||
|
t.Run("more", func(t *testing.T) {
|
||||||
|
lb := New(Config{})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||||
|
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||||
|
lb.Rebalance()
|
||||||
|
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||||
|
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||||
|
})
|
||||||
|
}
|
5
internal/net/http/loadbalancer/logger.go
Normal file
5
internal/net/http/loadbalancer/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
var logger = logrus.WithField("module", "load_balancer")
|
29
internal/net/http/loadbalancer/mode.go
Normal file
29
internal/net/http/loadbalancer/mode.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoundRobin Mode = "roundrobin"
|
||||||
|
LeastConn Mode = "leastconn"
|
||||||
|
IPHash Mode = "iphash"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (mode *Mode) ValidateUpdate() bool {
|
||||||
|
switch U.ToLowerNoSnake(string(*mode)) {
|
||||||
|
case "", string(RoundRobin):
|
||||||
|
*mode = RoundRobin
|
||||||
|
return true
|
||||||
|
case string(LeastConn):
|
||||||
|
*mode = LeastConn
|
||||||
|
return true
|
||||||
|
case string(IPHash):
|
||||||
|
*mode = IPHash
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
*mode = RoundRobin
|
||||||
|
return false
|
||||||
|
}
|
22
internal/net/http/loadbalancer/round_robin.go
Normal file
22
internal/net/http/loadbalancer/round_robin.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundRobin struct {
|
||||||
|
index atomic.Uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||||
|
func (lb *roundRobin) OnAddServer(srv *Server) {}
|
||||||
|
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||||
|
|
||||||
|
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||||
|
index := lb.index.Add(1)
|
||||||
|
srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r)
|
||||||
|
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||||
|
lb.index.Store(0)
|
||||||
|
}
|
||||||
|
}
|
67
internal/net/http/loadbalancer/server.go
Normal file
67
internal/net/http/loadbalancer/server.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package loadbalancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Server struct {
|
||||||
|
Name string
|
||||||
|
URL types.URL
|
||||||
|
Weight weightType
|
||||||
|
handler http.Handler
|
||||||
|
|
||||||
|
pinger *http.Client
|
||||||
|
available atomic.Bool
|
||||||
|
}
|
||||||
|
servers []*Server
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
|
||||||
|
srv := &Server{
|
||||||
|
Name: name,
|
||||||
|
URL: url,
|
||||||
|
Weight: weightType(weight),
|
||||||
|
handler: handler,
|
||||||
|
pinger: &http.Client{Timeout: 3 * time.Second},
|
||||||
|
}
|
||||||
|
srv.available.Store(true)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) checkUpdateAvail(ctx context.Context) {
|
||||||
|
req, err := http.NewRequestWithContext(
|
||||||
|
ctx,
|
||||||
|
http.MethodHead,
|
||||||
|
srv.URL.String(),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to create request: ", err)
|
||||||
|
srv.available.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := srv.pinger.Do(req)
|
||||||
|
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
if !srv.available.Swap(true) {
|
||||||
|
logger.Infof("server %s is up", srv.Name)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
if srv.available.Swap(false) {
|
||||||
|
logger.Warnf("server %s is down: %s", srv.Name, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if srv.available.Swap(false) {
|
||||||
|
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) String() string {
|
||||||
|
return srv.Name
|
||||||
|
}
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -29,7 +29,6 @@ func init() {
|
||||||
"setxforwarded": SetXForwarded,
|
"setxforwarded": SetXForwarded,
|
||||||
"hidexforwarded": HideXForwarded,
|
"hidexforwarded": HideXForwarded,
|
||||||
"redirecthttp": RedirectHTTP,
|
"redirecthttp": RedirectHTTP,
|
||||||
"forwardauth": ForwardAuth.m,
|
|
||||||
"modifyresponse": ModifyResponse.m,
|
"modifyresponse": ModifyResponse.m,
|
||||||
"modifyrequest": ModifyRequest.m,
|
"modifyrequest": ModifyRequest.m,
|
||||||
"errorpage": CustomErrorPage,
|
"errorpage": CustomErrorPage,
|
||||||
|
@ -37,6 +36,10 @@ func init() {
|
||||||
"realip": RealIP.m,
|
"realip": RealIP.m,
|
||||||
"cloudflarerealip": CloudflareRealIP.m,
|
"cloudflarerealip": CloudflareRealIP.m,
|
||||||
"cidrwhitelist": CIDRWhiteList.m,
|
"cidrwhitelist": CIDRWhiteList.m,
|
||||||
|
|
||||||
|
// !experimental
|
||||||
|
"forwardauth": ForwardAuth.m,
|
||||||
|
"oauth2": OAuth2.m,
|
||||||
}
|
}
|
||||||
names := make(map[*Middleware][]string)
|
names := make(map[*Middleware][]string)
|
||||||
for name, m := range middlewares {
|
for name, m := range middlewares {
|
||||||
|
|
129
internal/net/http/middleware/oauth2.go
Normal file
129
internal/net/http/middleware/oauth2.go
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oAuth2 struct {
|
||||||
|
*oAuth2Opts
|
||||||
|
m *Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
type oAuth2Opts struct {
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
AuthURL string // Authorization Endpoint
|
||||||
|
TokenURL string // Token Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
var OAuth2 = &oAuth2{
|
||||||
|
m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||||
|
oauth := new(oAuth2)
|
||||||
|
oauth.m = &Middleware{
|
||||||
|
impl: oauth,
|
||||||
|
before: oauth.handleOAuth2,
|
||||||
|
}
|
||||||
|
oauth.oAuth2Opts = &oAuth2Opts{}
|
||||||
|
err := Deserialize(opts, oauth.oAuth2Opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b := E.NewBuilder("missing required fields")
|
||||||
|
optV := reflect.ValueOf(oauth.oAuth2Opts)
|
||||||
|
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
|
||||||
|
if optV.FieldByName(field.Name).Len() == 0 {
|
||||||
|
b.Add(E.Missing(field.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if b.HasError() {
|
||||||
|
return nil, b.Build().Subject("oAuth2")
|
||||||
|
}
|
||||||
|
return oauth.m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||||
|
// Check if the user is authenticated (you may use session, cookie, etc.)
|
||||||
|
if !userIsAuthenticated(r) {
|
||||||
|
// TODO: Redirect to OAuth2 auth URL
|
||||||
|
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||||
|
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you have a token in the query string, process it
|
||||||
|
if code := r.URL.Query().Get("code"); code != "" {
|
||||||
|
// Exchange the authorization code for a token here
|
||||||
|
// Use the TokenURL and authenticate the user
|
||||||
|
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save token and user info based on your requirements
|
||||||
|
saveToken(rw, token)
|
||||||
|
|
||||||
|
// Redirect to the originally requested URL
|
||||||
|
http.Redirect(rw, r, "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If user is authenticated, go to the next handler
|
||||||
|
next(rw, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func userIsAuthenticated(r *http.Request) bool {
|
||||||
|
// Example: Check for a session or cookie
|
||||||
|
session, err := r.Cookie("session_token")
|
||||||
|
if err != nil || session.Value == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Validate the session_token if necessary
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) {
|
||||||
|
// Prepare the request body
|
||||||
|
data := url.Values{
|
||||||
|
"client_id": {opts.ClientID},
|
||||||
|
"client_secret": {opts.ClientSecret},
|
||||||
|
"code": {code},
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"redirect_uri": {requestUri},
|
||||||
|
}
|
||||||
|
resp, err := http.PostForm(opts.TokenURL, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to request token: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||||
|
}
|
||||||
|
// Decode the response
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode token response: %v", err)
|
||||||
|
}
|
||||||
|
return tokenResp.AccessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveToken(rw ResponseWriter, token string) {
|
||||||
|
// Example: Save token in cookie
|
||||||
|
http.SetCookie(rw, &http.Cookie{
|
||||||
|
Name: "auth_token",
|
||||||
|
Value: token,
|
||||||
|
// set other properties as necessary, such as Secure and HttpOnly
|
||||||
|
})
|
||||||
|
}
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed test_data/sample_headers.json
|
//go:embed test_data/sample_headers.json
|
||||||
|
@ -110,7 +111,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||||
} else {
|
} else {
|
||||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||||
}
|
}
|
||||||
rp := gphttp.NewReverseProxy(proxyURL, rr)
|
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr)
|
||||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||||
if setOptErr != nil {
|
if setOptErr != nil {
|
||||||
return nil, setOptErr
|
return nil, setOptErr
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http/httpguts"
|
"golang.org/x/net/http/httpguts"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -86,7 +87,7 @@ type ReverseProxy struct {
|
||||||
|
|
||||||
ServeHTTP http.HandlerFunc
|
ServeHTTP http.HandlerFunc
|
||||||
|
|
||||||
TargetURL *url.URL
|
TargetURL types.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func singleJoiningSlash(a, b string) string {
|
func singleJoiningSlash(a, b string) string {
|
||||||
|
@ -144,7 +145,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
|
|
||||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
|
||||||
if transport == nil {
|
if transport == nil {
|
||||||
panic("nil transport")
|
panic("nil transport")
|
||||||
}
|
}
|
||||||
|
@ -263,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriteRequestURL(outreq, p.TargetURL)
|
rewriteRequestURL(outreq, p.TargetURL.URL)
|
||||||
outreq.Close = false
|
outreq.Close = false
|
||||||
|
|
||||||
reqUpType := UpgradeType(outreq.Header)
|
reqUpType := UpgradeType(outreq.Header)
|
||||||
|
@ -348,18 +349,16 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
roundTripMutex.Unlock()
|
roundTripMutex.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.errorHandler(rw, outreq, err, false)
|
p.errorHandler(rw, outreq, err, false)
|
||||||
errMsg := err.Error()
|
|
||||||
res = &http.Response{
|
res = &http.Response{
|
||||||
Status: http.StatusText(http.StatusBadGateway),
|
Status: http.StatusText(http.StatusBadGateway),
|
||||||
StatusCode: http.StatusBadGateway,
|
StatusCode: http.StatusBadGateway,
|
||||||
Proto: outreq.Proto,
|
Proto: outreq.Proto,
|
||||||
ProtoMajor: outreq.ProtoMajor,
|
ProtoMajor: outreq.ProtoMajor,
|
||||||
ProtoMinor: outreq.ProtoMinor,
|
ProtoMinor: outreq.ProtoMinor,
|
||||||
Header: make(http.Header),
|
Header: make(http.Header),
|
||||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||||
Request: outreq,
|
Request: outreq,
|
||||||
ContentLength: int64(len(errMsg)),
|
TLS: outreq.TLS,
|
||||||
TLS: outreq.TLS,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
24
internal/net/types/url.go
Normal file
24
internal/net/types/url.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import "net/url"
|
||||||
|
|
||||||
|
type URL struct{ *url.URL }
|
||||||
|
|
||||||
|
func NewURL(url *url.URL) URL {
|
||||||
|
return URL{url}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u URL) String() string {
|
||||||
|
if u.URL == nil {
|
||||||
|
return "nil"
|
||||||
|
}
|
||||||
|
return u.URL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u URL) MarshalText() (text []byte, err error) {
|
||||||
|
return []byte(u.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u URL) Equals(other URL) bool {
|
||||||
|
return u.URL == other.URL || u.String() == other.String()
|
||||||
|
}
|
|
@ -7,6 +7,8 @@ import (
|
||||||
|
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||||
|
net "github.com/yusing/go-proxy/internal/net/types"
|
||||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||||
"github.com/yusing/go-proxy/internal/types"
|
"github.com/yusing/go-proxy/internal/types"
|
||||||
)
|
)
|
||||||
|
@ -15,9 +17,10 @@ type (
|
||||||
ReverseProxyEntry struct { // real model after validation
|
ReverseProxyEntry struct { // real model after validation
|
||||||
Alias T.Alias
|
Alias T.Alias
|
||||||
Scheme T.Scheme
|
Scheme T.Scheme
|
||||||
URL *url.URL
|
URL net.URL
|
||||||
NoTLSVerify bool
|
NoTLSVerify bool
|
||||||
PathPatterns T.PathPatterns
|
PathPatterns T.PathPatterns
|
||||||
|
LoadBalance loadbalancer.Config
|
||||||
Middlewares D.NestedLabelMap
|
Middlewares D.NestedLabelMap
|
||||||
|
|
||||||
/* Docker only */
|
/* Docker only */
|
||||||
|
@ -47,6 +50,10 @@ func (rp *ReverseProxyEntry) IsDocker() bool {
|
||||||
return rp.DockerHost != ""
|
return rp.DockerHost != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxyEntry) IsZeroPort() bool {
|
||||||
|
return rp.URL.Port() == "0"
|
||||||
|
}
|
||||||
|
|
||||||
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
|
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
|
||||||
m.FillMissingFields()
|
m.FillMissingFields()
|
||||||
|
|
||||||
|
@ -107,9 +114,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
|
||||||
return &ReverseProxyEntry{
|
return &ReverseProxyEntry{
|
||||||
Alias: T.NewAlias(m.Alias),
|
Alias: T.NewAlias(m.Alias),
|
||||||
Scheme: s,
|
Scheme: s,
|
||||||
URL: url,
|
URL: net.NewURL(url),
|
||||||
NoTLSVerify: m.NoTLSVerify,
|
NoTLSVerify: m.NoTLSVerify,
|
||||||
PathPatterns: pathPatterns,
|
PathPatterns: pathPatterns,
|
||||||
|
LoadBalance: m.LoadBalance,
|
||||||
Middlewares: m.Middlewares,
|
Middlewares: m.Middlewares,
|
||||||
IdleTimeout: idleTimeout,
|
IdleTimeout: idleTimeout,
|
||||||
WakeTimeout: wakeTimeout,
|
WakeTimeout: wakeTimeout,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -13,6 +12,7 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
. "github.com/yusing/go-proxy/internal/net/http"
|
. "github.com/yusing/go-proxy/internal/net/http"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||||
P "github.com/yusing/go-proxy/internal/proxy"
|
P "github.com/yusing/go-proxy/internal/proxy"
|
||||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||||
|
@ -21,16 +21,14 @@ import (
|
||||||
|
|
||||||
type (
|
type (
|
||||||
HTTPRoute struct {
|
HTTPRoute struct {
|
||||||
Alias PT.Alias `json:"alias"`
|
*P.ReverseProxyEntry
|
||||||
TargetURL *URL `json:"target_url"`
|
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"`
|
||||||
PathPatterns PT.PathPatterns `json:"path_patterns"`
|
|
||||||
|
|
||||||
entry *P.ReverseProxyEntry
|
server *loadbalancer.Server
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
rp *ReverseProxy
|
rp *ReverseProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
URL url.URL
|
|
||||||
SubdomainKey = PT.Alias
|
SubdomainKey = PT.Alias
|
||||||
|
|
||||||
ReverseProxyHandler struct {
|
ReverseProxyHandler struct {
|
||||||
|
@ -80,11 +78,8 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||||
defer httpRoutesMu.Unlock()
|
defer httpRoutesMu.Unlock()
|
||||||
|
|
||||||
r := &HTTPRoute{
|
r := &HTTPRoute{
|
||||||
Alias: entry.Alias,
|
ReverseProxyEntry: entry,
|
||||||
TargetURL: (*URL)(entry.URL),
|
rp: rp,
|
||||||
PathPatterns: entry.PathPatterns,
|
|
||||||
entry: entry,
|
|
||||||
rp: rp,
|
|
||||||
}
|
}
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
@ -101,18 +96,19 @@ func (r *HTTPRoute) Start() E.NestedError {
|
||||||
httpRoutesMu.Lock()
|
httpRoutesMu.Lock()
|
||||||
defer httpRoutesMu.Unlock()
|
defer httpRoutesMu.Unlock()
|
||||||
|
|
||||||
if r.entry.UseIdleWatcher() {
|
switch {
|
||||||
watcher, err := idlewatcher.Register(r.entry)
|
case r.UseIdleWatcher():
|
||||||
|
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.handler = idlewatcher.NewWaker(watcher, r.rp)
|
r.handler = idlewatcher.NewWaker(watcher, r.rp)
|
||||||
} else if r.entry.URL.Port() == "0" ||
|
case r.IsZeroPort() ||
|
||||||
r.entry.IsDocker() && !r.entry.ContainerRunning {
|
r.IsDocker() && !r.ContainerRunning:
|
||||||
return nil
|
return nil
|
||||||
} else if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" {
|
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
|
||||||
r.handler = ReverseProxyHandler{r.rp}
|
r.handler = ReverseProxyHandler{r.rp}
|
||||||
} else {
|
default:
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
for _, p := range r.PathPatterns {
|
for _, p := range r.PathPatterns {
|
||||||
mux.HandleFunc(string(p), r.rp.ServeHTTP)
|
mux.HandleFunc(string(p), r.rp.ServeHTTP)
|
||||||
|
@ -120,7 +116,25 @@ func (r *HTTPRoute) Start() E.NestedError {
|
||||||
r.handler = mux
|
r.handler = mux
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRoutes.Store(string(r.Alias), r)
|
if r.LoadBalance.Link == "" {
|
||||||
|
httpRoutes.Store(string(r.Alias), r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var lb *loadbalancer.LoadBalancer
|
||||||
|
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
|
||||||
|
if ok {
|
||||||
|
lb = linked.LoadBalancer
|
||||||
|
} else {
|
||||||
|
lb = loadbalancer.New(r.LoadBalance)
|
||||||
|
lb.Start()
|
||||||
|
linked = &HTTPRoute{
|
||||||
|
LoadBalancer: lb,
|
||||||
|
handler: lb,
|
||||||
|
}
|
||||||
|
httpRoutes.Store(string(r.LoadBalance.Link), linked)
|
||||||
|
}
|
||||||
|
lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,7 +151,18 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.handler = nil
|
r.handler = nil
|
||||||
httpRoutes.Delete(string(r.Alias))
|
|
||||||
|
if r.server != nil {
|
||||||
|
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
|
||||||
|
if ok {
|
||||||
|
linked.LoadBalancer.RemoveServer(r.server)
|
||||||
|
}
|
||||||
|
if linked.LoadBalancer.IsEmpty() {
|
||||||
|
httpRoutes.Delete(string(r.LoadBalance.Link))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
httpRoutes.Delete(string(r.Alias))
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,14 +170,6 @@ func (r *HTTPRoute) Started() bool {
|
||||||
return r.handler != nil
|
return r.handler != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *URL) String() string {
|
|
||||||
return (*url.URL)(u).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *URL) MarshalText() (text []byte, err error) {
|
|
||||||
return []byte(u.String()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
mux, err := findMuxFunc(r.Host)
|
mux, err := findMuxFunc(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
. "github.com/yusing/go-proxy/internal/common"
|
. "github.com/yusing/go-proxy/internal/common"
|
||||||
D "github.com/yusing/go-proxy/internal/docker"
|
D "github.com/yusing/go-proxy/internal/docker"
|
||||||
H "github.com/yusing/go-proxy/internal/homepage"
|
H "github.com/yusing/go-proxy/internal/homepage"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,14 +16,15 @@ type (
|
||||||
RawEntry struct {
|
RawEntry struct {
|
||||||
// raw entry object before validation
|
// raw entry object before validation
|
||||||
// loaded from docker labels or yaml file
|
// loaded from docker labels or yaml file
|
||||||
Alias string `yaml:"-" json:"-"`
|
Alias string `yaml:"-" json:"-"`
|
||||||
Scheme string `yaml:"scheme" json:"scheme"`
|
Scheme string `yaml:"scheme" json:"scheme"`
|
||||||
Host string `yaml:"host" json:"host"`
|
Host string `yaml:"host" json:"host"`
|
||||||
Port string `yaml:"port" json:"port"`
|
Port string `yaml:"port" json:"port"`
|
||||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only
|
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only
|
||||||
PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only
|
PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only
|
||||||
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"`
|
LoadBalance loadbalancer.Config `yaml:"load_balance" json:"load_balance"`
|
||||||
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage"`
|
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"`
|
||||||
|
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage,omitempty"`
|
||||||
|
|
||||||
/* Docker only */
|
/* Docker only */
|
||||||
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
||||||
|
|
Loading…
Add table
Reference in a new issue