mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-01 04:54:26 +02:00
added round_robin, least_conn and ip_hash load balance support, small refactoring
This commit is contained in:
parent
1797896fa6
commit
5c40f4aa84
24 changed files with 739 additions and 64 deletions
|
@ -212,7 +212,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
|
|||
continue
|
||||
}
|
||||
cfg.proxyProviders.Store(p.GetName(), p)
|
||||
b.Add(p.LoadRoutes().Subject(dockerHost))
|
||||
b.Add(p.LoadRoutes().Subject(p.GetName()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package idlewatcher
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
@ -19,10 +18,6 @@ type Waker struct {
|
|||
}
|
||||
|
||||
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
||||
tr := &http.Transport{}
|
||||
if w.NoTLSVerify {
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
orig := rp.ServeHTTP
|
||||
// workaround for stopped containers port become zero
|
||||
rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -41,7 +36,7 @@ func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
|
|||
watcher: w,
|
||||
client: &http.Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Transport: tr,
|
||||
Transport: rp.Transport,
|
||||
},
|
||||
rp: rp,
|
||||
}
|
||||
|
|
|
@ -65,6 +65,10 @@ func (b Builder) To(ptr *NestedError) {
|
|||
}
|
||||
}
|
||||
|
||||
func (b Builder) String() string {
|
||||
return b.Build().String()
|
||||
}
|
||||
|
||||
func (b Builder) HasError() bool {
|
||||
return len(b.errors) > 0
|
||||
}
|
||||
|
|
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ipHash struct{ *LoadBalancer }
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
|
||||
func (ipHash) OnAddServer(srv *Server) {}
|
||||
func (ipHash) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
if !impl.pool[idx].available.Load() {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
}
|
||||
impl.pool[idx].handler.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func hashIP(ip string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(ip))
|
||||
return h.Sum32()
|
||||
}
|
53
internal/net/http/loadbalancer/least_conn.go
Normal file
53
internal/net/http/loadbalancer/least_conn.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type leastConn struct {
|
||||
*LoadBalancer
|
||||
nConn F.Map[*Server, *atomic.Int64]
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newLeastConn() impl {
|
||||
return &leastConn{
|
||||
LoadBalancer: lb,
|
||||
nConn: F.NewMapOf[*Server, *atomic.Int64](),
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnAddServer(srv *Server) {
|
||||
impl.nConn.Store(srv, new(atomic.Int64))
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnRemoveServer(srv *Server) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
minConn = nConn
|
||||
srv = srvs[i]
|
||||
}
|
||||
}
|
||||
|
||||
minConn.Add(1)
|
||||
srv.handler.ServeHTTP(rw, r)
|
||||
minConn.Add(-1)
|
||||
}
|
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
|
@ -0,0 +1,241 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/log"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
// TODO: stats of each server
|
||||
// TODO: support weighted mode
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv *Server)
|
||||
OnRemoveServer(srv *Server)
|
||||
}
|
||||
Config struct {
|
||||
Link string
|
||||
Mode Mode
|
||||
Weight weightType
|
||||
}
|
||||
LoadBalancer struct {
|
||||
impl
|
||||
Config
|
||||
|
||||
pool servers
|
||||
poolMu sync.RWMutex
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
|
||||
sumWeight weightType
|
||||
}
|
||||
|
||||
weightType uint16
|
||||
)
|
||||
|
||||
const maxWeight weightType = 100
|
||||
|
||||
func New(cfg Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{Config: cfg, pool: servers{}}
|
||||
mode := cfg.Mode
|
||||
if !cfg.Mode.ValidateUpdate() {
|
||||
logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode)
|
||||
}
|
||||
switch mode {
|
||||
case RoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
}
|
||||
return lb
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) AddServer(srv *Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.pool = append(lb.pool, srv)
|
||||
lb.sumWeight += srv.Weight
|
||||
|
||||
lb.impl.OnAddServer(srv)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
lb.poolMu.RLock()
|
||||
defer lb.poolMu.RUnlock()
|
||||
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
for i, s := range lb.pool {
|
||||
if s == srv {
|
||||
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if lb.IsEmpty() {
|
||||
lb.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) IsEmpty() bool {
|
||||
return len(lb.pool) == 0
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Rebalance() {
|
||||
if lb.sumWeight == maxWeight {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / weightType(len(lb.pool))
|
||||
remainer := maxWeight % weightType(len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
s.Weight = weightEach
|
||||
lb.sumWeight += weightEach
|
||||
if remainer > 0 {
|
||||
s.Weight++
|
||||
}
|
||||
remainer--
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// scale evenly
|
||||
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||
lb.sumWeight = 0
|
||||
|
||||
for _, s := range lb.pool {
|
||||
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||
lb.sumWeight += s.Weight
|
||||
}
|
||||
|
||||
delta := maxWeight - lb.sumWeight
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
for _, s := range lb.pool {
|
||||
if delta == 0 {
|
||||
break
|
||||
}
|
||||
if delta > 0 {
|
||||
s.Weight++
|
||||
lb.sumWeight++
|
||||
delta--
|
||||
} else {
|
||||
s.Weight--
|
||||
lb.sumWeight--
|
||||
delta++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
srvs := lb.availServers()
|
||||
if len(srvs) == 0 {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Start() {
|
||||
if lb.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
|
||||
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
|
||||
for _, s := range lb.pool {
|
||||
msg.Addf("%s: %d", s.Name, s.Weight)
|
||||
}
|
||||
lb.Rebalance()
|
||||
inner := E.NewBuilder("After rebalancing")
|
||||
for _, s := range lb.pool {
|
||||
inner.Addf("%s: %d", s.Name, s.Weight)
|
||||
}
|
||||
msg.Addf("%s", inner)
|
||||
logger.Warn(msg)
|
||||
}
|
||||
|
||||
if lb.sumWeight != 0 {
|
||||
log.Warnf("Weighted mode not supported yet")
|
||||
}
|
||||
|
||||
switch lb.Mode {
|
||||
case RoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
}
|
||||
|
||||
lb.done = make(chan struct{}, 1)
|
||||
lb.ctx, lb.cancel = context.WithCancel(context.Background())
|
||||
|
||||
updateAll := func() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
go func(s *Server) {
|
||||
defer wg.Done()
|
||||
s.checkUpdateAvail(lb.ctx)
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer lb.cancel()
|
||||
defer close(lb.done)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
updateAll()
|
||||
for {
|
||||
select {
|
||||
case <-lb.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
lb.poolMu.RLock()
|
||||
updateAll()
|
||||
lb.poolMu.RUnlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Stop() {
|
||||
if lb.impl == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lb.cancel()
|
||||
|
||||
<-lb.done
|
||||
lb.pool = nil
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() servers {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
avail := servers{}
|
||||
for _, s := range lb.pool {
|
||||
if s.available.Load() {
|
||||
avail = append(avail, s)
|
||||
}
|
||||
}
|
||||
return avail
|
||||
}
|
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
for range 10 {
|
||||
lb.AddServer(&Server{})
|
||||
}
|
||||
lb.Rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
}
|
5
internal/net/http/loadbalancer/logger.go
Normal file
5
internal/net/http/loadbalancer/logger.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package loadbalancer
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
|
||||
var logger = logrus.WithField("module", "load_balancer")
|
29
internal/net/http/loadbalancer/mode.go
Normal file
29
internal/net/http/loadbalancer/mode.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
RoundRobin Mode = "roundrobin"
|
||||
LeastConn Mode = "leastconn"
|
||||
IPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch U.ToLowerNoSnake(string(*mode)) {
|
||||
case "", string(RoundRobin):
|
||||
*mode = RoundRobin
|
||||
return true
|
||||
case string(LeastConn):
|
||||
*mode = LeastConn
|
||||
return true
|
||||
case string(IPHash):
|
||||
*mode = IPHash
|
||||
return true
|
||||
}
|
||||
*mode = RoundRobin
|
||||
return false
|
||||
}
|
22
internal/net/http/loadbalancer/round_robin.go
Normal file
22
internal/net/http/loadbalancer/round_robin.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type roundRobin struct {
|
||||
index atomic.Uint32
|
||||
}
|
||||
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv *Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1)
|
||||
srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
lb.index.Store(0)
|
||||
}
|
||||
}
|
67
internal/net/http/loadbalancer/server.go
Normal file
67
internal/net/http/loadbalancer/server.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Server struct {
|
||||
Name string
|
||||
URL types.URL
|
||||
Weight weightType
|
||||
handler http.Handler
|
||||
|
||||
pinger *http.Client
|
||||
available atomic.Bool
|
||||
}
|
||||
servers []*Server
|
||||
)
|
||||
|
||||
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
|
||||
srv := &Server{
|
||||
Name: name,
|
||||
URL: url,
|
||||
Weight: weightType(weight),
|
||||
handler: handler,
|
||||
pinger: &http.Client{Timeout: 3 * time.Second},
|
||||
}
|
||||
srv.available.Store(true)
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) checkUpdateAvail(ctx context.Context) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodHead,
|
||||
srv.URL.String(),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("failed to create request: ", err)
|
||||
srv.available.Store(false)
|
||||
}
|
||||
|
||||
resp, err := srv.pinger.Do(req)
|
||||
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
|
||||
if !srv.available.Swap(true) {
|
||||
logger.Infof("server %s is up", srv.Name)
|
||||
}
|
||||
} else if err != nil {
|
||||
if srv.available.Swap(false) {
|
||||
logger.Warnf("server %s is down: %s", srv.Name, err)
|
||||
}
|
||||
} else {
|
||||
if srv.available.Swap(false) {
|
||||
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) String() string {
|
||||
return srv.Name
|
||||
}
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -29,7 +29,6 @@ func init() {
|
|||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"modifyresponse": ModifyResponse.m,
|
||||
"modifyrequest": ModifyRequest.m,
|
||||
"errorpage": CustomErrorPage,
|
||||
|
@ -37,6 +36,10 @@ func init() {
|
|||
"realip": RealIP.m,
|
||||
"cloudflarerealip": CloudflareRealIP.m,
|
||||
"cidrwhitelist": CIDRWhiteList.m,
|
||||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"oauth2": OAuth2.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
|
|
129
internal/net/http/middleware/oauth2.go
Normal file
129
internal/net/http/middleware/oauth2.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type oAuth2 struct {
|
||||
*oAuth2Opts
|
||||
m *Middleware
|
||||
}
|
||||
|
||||
type oAuth2Opts struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string // Authorization Endpoint
|
||||
TokenURL string // Token Endpoint
|
||||
}
|
||||
|
||||
var OAuth2 = &oAuth2{
|
||||
m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
}
|
||||
|
||||
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
oauth := new(oAuth2)
|
||||
oauth.m = &Middleware{
|
||||
impl: oauth,
|
||||
before: oauth.handleOAuth2,
|
||||
}
|
||||
oauth.oAuth2Opts = &oAuth2Opts{}
|
||||
err := Deserialize(opts, oauth.oAuth2Opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := E.NewBuilder("missing required fields")
|
||||
optV := reflect.ValueOf(oauth.oAuth2Opts)
|
||||
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
|
||||
if optV.FieldByName(field.Name).Len() == 0 {
|
||||
b.Add(E.Missing(field.Name))
|
||||
}
|
||||
}
|
||||
if b.HasError() {
|
||||
return nil, b.Build().Subject("oAuth2")
|
||||
}
|
||||
return oauth.m, nil
|
||||
}
|
||||
|
||||
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||
// Check if the user is authenticated (you may use session, cookie, etc.)
|
||||
if !userIsAuthenticated(r) {
|
||||
// TODO: Redirect to OAuth2 auth URL
|
||||
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If you have a token in the query string, process it
|
||||
if code := r.URL.Query().Get("code"); code != "" {
|
||||
// Exchange the authorization code for a token here
|
||||
// Use the TokenURL and authenticate the user
|
||||
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
|
||||
if err != nil {
|
||||
// handle error
|
||||
http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Save token and user info based on your requirements
|
||||
saveToken(rw, token)
|
||||
|
||||
// Redirect to the originally requested URL
|
||||
http.Redirect(rw, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If user is authenticated, go to the next handler
|
||||
next(rw, r)
|
||||
}
|
||||
|
||||
func userIsAuthenticated(r *http.Request) bool {
|
||||
// Example: Check for a session or cookie
|
||||
session, err := r.Cookie("session_token")
|
||||
if err != nil || session.Value == "" {
|
||||
return false
|
||||
}
|
||||
// Validate the session_token if necessary
|
||||
return true
|
||||
}
|
||||
|
||||
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) {
|
||||
// Prepare the request body
|
||||
data := url.Values{
|
||||
"client_id": {opts.ClientID},
|
||||
"client_secret": {opts.ClientSecret},
|
||||
"code": {code},
|
||||
"grant_type": {"authorization_code"},
|
||||
"redirect_uri": {requestUri},
|
||||
}
|
||||
resp, err := http.PostForm(opts.TokenURL, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to request token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||
}
|
||||
// Decode the response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %v", err)
|
||||
}
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
func saveToken(rw ResponseWriter, token string) {
|
||||
// Example: Save token in cookie
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: "auth_token",
|
||||
Value: token,
|
||||
// set other properties as necessary, such as Secure and HttpOnly
|
||||
})
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
//go:embed test_data/sample_headers.json
|
||||
|
@ -110,7 +111,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
|||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gphttp.NewReverseProxy(proxyURL, rr)
|
||||
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -86,7 +87,7 @@ type ReverseProxy struct {
|
|||
|
||||
ServeHTTP http.HandlerFunc
|
||||
|
||||
TargetURL *url.URL
|
||||
TargetURL types.URL
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
|
@ -144,7 +145,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
|||
// }
|
||||
//
|
||||
|
||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
if transport == nil {
|
||||
panic("nil transport")
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
rewriteRequestURL(outreq, p.TargetURL)
|
||||
rewriteRequestURL(outreq, p.TargetURL.URL)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := UpgradeType(outreq.Header)
|
||||
|
@ -348,7 +349,6 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
roundTripMutex.Unlock()
|
||||
if err != nil {
|
||||
p.errorHandler(rw, outreq, err, false)
|
||||
errMsg := err.Error()
|
||||
res = &http.Response{
|
||||
Status: http.StatusText(http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
|
@ -358,7 +358,6 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||
Request: outreq,
|
||||
ContentLength: int64(len(errMsg)),
|
||||
TLS: outreq.TLS,
|
||||
}
|
||||
}
|
||||
|
|
24
internal/net/types/url.go
Normal file
24
internal/net/types/url.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package types
|
||||
|
||||
import "net/url"
|
||||
|
||||
type URL struct{ *url.URL }
|
||||
|
||||
func NewURL(url *url.URL) URL {
|
||||
return URL{url}
|
||||
}
|
||||
|
||||
func (u URL) String() string {
|
||||
if u.URL == nil {
|
||||
return "nil"
|
||||
}
|
||||
return u.URL.String()
|
||||
}
|
||||
|
||||
func (u URL) MarshalText() (text []byte, err error) {
|
||||
return []byte(u.String()), nil
|
||||
}
|
||||
|
||||
func (u URL) Equals(other URL) bool {
|
||||
return u.URL == other.URL || u.String() == other.String()
|
||||
}
|
|
@ -7,6 +7,8 @@ import (
|
|||
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
@ -15,9 +17,10 @@ type (
|
|||
ReverseProxyEntry struct { // real model after validation
|
||||
Alias T.Alias
|
||||
Scheme T.Scheme
|
||||
URL *url.URL
|
||||
URL net.URL
|
||||
NoTLSVerify bool
|
||||
PathPatterns T.PathPatterns
|
||||
LoadBalance loadbalancer.Config
|
||||
Middlewares D.NestedLabelMap
|
||||
|
||||
/* Docker only */
|
||||
|
@ -47,6 +50,10 @@ func (rp *ReverseProxyEntry) IsDocker() bool {
|
|||
return rp.DockerHost != ""
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) IsZeroPort() bool {
|
||||
return rp.URL.Port() == "0"
|
||||
}
|
||||
|
||||
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
|
||||
m.FillMissingFields()
|
||||
|
||||
|
@ -107,9 +114,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
|
|||
return &ReverseProxyEntry{
|
||||
Alias: T.NewAlias(m.Alias),
|
||||
Scheme: s,
|
||||
URL: url,
|
||||
URL: net.NewURL(url),
|
||||
NoTLSVerify: m.NoTLSVerify,
|
||||
PathPatterns: pathPatterns,
|
||||
LoadBalance: m.LoadBalance,
|
||||
Middlewares: m.Middlewares,
|
||||
IdleTimeout: idleTimeout,
|
||||
WakeTimeout: wakeTimeout,
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"sync"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -13,6 +12,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
|
@ -21,16 +21,14 @@ import (
|
|||
|
||||
type (
|
||||
HTTPRoute struct {
|
||||
Alias PT.Alias `json:"alias"`
|
||||
TargetURL *URL `json:"target_url"`
|
||||
PathPatterns PT.PathPatterns `json:"path_patterns"`
|
||||
*P.ReverseProxyEntry
|
||||
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"`
|
||||
|
||||
entry *P.ReverseProxyEntry
|
||||
server *loadbalancer.Server
|
||||
handler http.Handler
|
||||
rp *ReverseProxy
|
||||
}
|
||||
|
||||
URL url.URL
|
||||
SubdomainKey = PT.Alias
|
||||
|
||||
ReverseProxyHandler struct {
|
||||
|
@ -80,10 +78,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
|||
defer httpRoutesMu.Unlock()
|
||||
|
||||
r := &HTTPRoute{
|
||||
Alias: entry.Alias,
|
||||
TargetURL: (*URL)(entry.URL),
|
||||
PathPatterns: entry.PathPatterns,
|
||||
entry: entry,
|
||||
ReverseProxyEntry: entry,
|
||||
rp: rp,
|
||||
}
|
||||
return r, nil
|
||||
|
@ -101,18 +96,19 @@ func (r *HTTPRoute) Start() E.NestedError {
|
|||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
if r.entry.UseIdleWatcher() {
|
||||
watcher, err := idlewatcher.Register(r.entry)
|
||||
switch {
|
||||
case r.UseIdleWatcher():
|
||||
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.handler = idlewatcher.NewWaker(watcher, r.rp)
|
||||
} else if r.entry.URL.Port() == "0" ||
|
||||
r.entry.IsDocker() && !r.entry.ContainerRunning {
|
||||
case r.IsZeroPort() ||
|
||||
r.IsDocker() && !r.ContainerRunning:
|
||||
return nil
|
||||
} else if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" {
|
||||
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
|
||||
r.handler = ReverseProxyHandler{r.rp}
|
||||
} else {
|
||||
default:
|
||||
mux := http.NewServeMux()
|
||||
for _, p := range r.PathPatterns {
|
||||
mux.HandleFunc(string(p), r.rp.ServeHTTP)
|
||||
|
@ -120,10 +116,28 @@ func (r *HTTPRoute) Start() E.NestedError {
|
|||
r.handler = mux
|
||||
}
|
||||
|
||||
if r.LoadBalance.Link == "" {
|
||||
httpRoutes.Store(string(r.Alias), r)
|
||||
return nil
|
||||
}
|
||||
|
||||
var lb *loadbalancer.LoadBalancer
|
||||
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
|
||||
if ok {
|
||||
lb = linked.LoadBalancer
|
||||
} else {
|
||||
lb = loadbalancer.New(r.LoadBalance)
|
||||
lb.Start()
|
||||
linked = &HTTPRoute{
|
||||
LoadBalancer: lb,
|
||||
handler: lb,
|
||||
}
|
||||
httpRoutes.Store(string(r.LoadBalance.Link), linked)
|
||||
}
|
||||
lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Stop() (_ E.NestedError) {
|
||||
if r.handler == nil {
|
||||
return
|
||||
|
@ -137,7 +151,18 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
|
|||
}
|
||||
|
||||
r.handler = nil
|
||||
|
||||
if r.server != nil {
|
||||
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
|
||||
if ok {
|
||||
linked.LoadBalancer.RemoveServer(r.server)
|
||||
}
|
||||
if linked.LoadBalancer.IsEmpty() {
|
||||
httpRoutes.Delete(string(r.LoadBalance.Link))
|
||||
}
|
||||
} else {
|
||||
httpRoutes.Delete(string(r.Alias))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -145,14 +170,6 @@ func (r *HTTPRoute) Started() bool {
|
|||
return r.handler != nil
|
||||
}
|
||||
|
||||
func (u *URL) String() string {
|
||||
return (*url.URL)(u).String()
|
||||
}
|
||||
|
||||
func (u *URL) MarshalText() (text []byte, err error) {
|
||||
return []byte(u.String()), nil
|
||||
}
|
||||
|
||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
mux, err := findMuxFunc(r.Host)
|
||||
if err != nil {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
. "github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
H "github.com/yusing/go-proxy/internal/homepage"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
|
@ -21,8 +22,9 @@ type (
|
|||
Port string `yaml:"port" json:"port"`
|
||||
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only
|
||||
PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only
|
||||
LoadBalance loadbalancer.Config `yaml:"load_balance" json:"load_balance"`
|
||||
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"`
|
||||
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage"`
|
||||
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage,omitempty"`
|
||||
|
||||
/* Docker only */
|
||||
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`
|
||||
|
|
Loading…
Add table
Reference in a new issue