added round_robin, least_conn and ip_hash load balance support, small refactoring

This commit is contained in:
yusing 2024-10-09 10:39:07 +08:00
parent 1797896fa6
commit 5c40f4aa84
24 changed files with 739 additions and 64 deletions

View file

@ -212,7 +212,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(dockerHost))
b.Add(p.LoadRoutes().Subject(p.GetName()))
}
return
}

View file

@ -2,7 +2,6 @@ package idlewatcher
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strconv"
@ -19,10 +18,6 @@ type Waker struct {
}
func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
tr := &http.Transport{}
if w.NoTLSVerify {
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
orig := rp.ServeHTTP
// workaround for stopped containers port become zero
rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) {
@ -41,7 +36,7 @@ func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker {
watcher: w,
client: &http.Client{
Timeout: 1 * time.Second,
Transport: tr,
Transport: rp.Transport,
},
rp: rp,
}

View file

@ -65,6 +65,10 @@ func (b Builder) To(ptr *NestedError) {
}
}
func (b Builder) String() string {
return b.Build().String()
}
func (b Builder) HasError() bool {
return len(b.errors) > 0
}

View file

@ -0,0 +1,33 @@
package loadbalancer
import (
"hash/fnv"
"net"
"net/http"
)
type ipHash struct{ *LoadBalancer }
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
if !impl.pool[idx].available.Load() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
}
impl.pool[idx].handler.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {
h := fnv.New32a()
h.Write([]byte(ip))
return h.Sum32()
}

View file

@ -0,0 +1,53 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type leastConn struct {
*LoadBalancer
nConn F.Map[*Server, *atomic.Int64]
}
func (lb *LoadBalancer) newLeastConn() impl {
return &leastConn{
LoadBalancer: lb,
nConn: F.NewMapOf[*Server, *atomic.Int64](),
}
}
func (impl *leastConn) OnAddServer(srv *Server) {
impl.nConn.Store(srv, new(atomic.Int64))
}
func (impl *leastConn) OnRemoveServer(srv *Server) {
impl.nConn.Delete(srv)
}
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
srv := srvs[0]
minConn, ok := impl.nConn.Load(srv)
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i])
if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError)
}
if nConn.Load() < minConn.Load() {
minConn = nConn
srv = srvs[i]
}
}
minConn.Add(1)
srv.handler.ServeHTTP(rw, r)
minConn.Add(-1)
}

View file

@ -0,0 +1,241 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
)
// TODO: stats of each server
// TODO: support weighted mode
type (
impl interface {
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
OnAddServer(srv *Server)
OnRemoveServer(srv *Server)
}
Config struct {
Link string
Mode Mode
Weight weightType
}
LoadBalancer struct {
impl
Config
pool servers
poolMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
done chan struct{}
sumWeight weightType
}
weightType uint16
)
const maxWeight weightType = 100
func New(cfg Config) *LoadBalancer {
lb := &LoadBalancer{Config: cfg, pool: servers{}}
mode := cfg.Mode
if !cfg.Mode.ValidateUpdate() {
logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode)
}
switch mode {
case RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
lb.impl = lb.newIPHash()
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
return lb
}
func (lb *LoadBalancer) AddServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = append(lb.pool, srv)
lb.sumWeight += srv.Weight
lb.impl.OnAddServer(srv)
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.RLock()
defer lb.poolMu.RUnlock()
lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool {
if s == srv {
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
break
}
}
if lb.IsEmpty() {
lb.Stop()
}
}
func (lb *LoadBalancer) IsEmpty() bool {
return len(lb.pool) == 0
}
func (lb *LoadBalancer) Rebalance() {
if lb.sumWeight == maxWeight {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(len(lb.pool))
remainer := maxWeight % weightType(len(lb.pool))
for _, s := range lb.pool {
s.Weight = weightEach
lb.sumWeight += weightEach
if remainer > 0 {
s.Weight++
}
remainer--
}
return
}
// scale evenly
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
for _, s := range lb.pool {
s.Weight = weightType(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight
}
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
for _, s := range lb.pool {
if delta == 0 {
break
}
if delta > 0 {
s.Weight++
lb.sumWeight++
delta--
} else {
s.Weight--
lb.sumWeight--
delta++
}
}
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
srvs := lb.availServers()
if len(srvs) == 0 {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
lb.impl.ServeHTTP(srvs, rw, r)
}
func (lb *LoadBalancer) Start() {
if lb.IsEmpty() {
return
}
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
for _, s := range lb.pool {
msg.Addf("%s: %d", s.Name, s.Weight)
}
lb.Rebalance()
inner := E.NewBuilder("After rebalancing")
for _, s := range lb.pool {
inner.Addf("%s: %d", s.Name, s.Weight)
}
msg.Addf("%s", inner)
logger.Warn(msg)
}
if lb.sumWeight != 0 {
log.Warnf("Weighted mode not supported yet")
}
switch lb.Mode {
case RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
case IPHash:
lb.impl = lb.newIPHash()
}
lb.done = make(chan struct{}, 1)
lb.ctx, lb.cancel = context.WithCancel(context.Background())
updateAll := func() {
var wg sync.WaitGroup
wg.Add(len(lb.pool))
for _, s := range lb.pool {
go func(s *Server) {
defer wg.Done()
s.checkUpdateAvail(lb.ctx)
}(s)
}
wg.Wait()
}
go func() {
defer lb.cancel()
defer close(lb.done)
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
updateAll()
for {
select {
case <-lb.ctx.Done():
return
case <-ticker.C:
lb.poolMu.RLock()
updateAll()
lb.poolMu.RUnlock()
}
}
}()
}
func (lb *LoadBalancer) Stop() {
if lb.impl == nil {
return
}
lb.cancel()
<-lb.done
lb.pool = nil
}
func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
avail := servers{}
for _, s := range lb.pool {
if s.available.Load() {
avail = append(avail, s)
}
}
return avail
}

View file

@ -0,0 +1,43 @@
package loadbalancer
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestRebalance(t *testing.T) {
t.Parallel()
t.Run("zero", func(t *testing.T) {
lb := New(Config{})
for range 10 {
lb.AddServer(&Server{})
}
lb.Rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
lb := New(Config{})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("more", func(t *testing.T) {
lb := New(Config{})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
}

View file

@ -0,0 +1,5 @@
package loadbalancer
import "github.com/sirupsen/logrus"
var logger = logrus.WithField("module", "load_balancer")

View file

@ -0,0 +1,29 @@
package loadbalancer
import (
U "github.com/yusing/go-proxy/internal/utils"
)
type Mode string
const (
RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn"
IPHash Mode = "iphash"
)
func (mode *Mode) ValidateUpdate() bool {
switch U.ToLowerNoSnake(string(*mode)) {
case "", string(RoundRobin):
*mode = RoundRobin
return true
case string(LeastConn):
*mode = LeastConn
return true
case string(IPHash):
*mode = IPHash
return true
}
*mode = RoundRobin
return false
}

View file

@ -0,0 +1,22 @@
package loadbalancer
import (
"net/http"
"sync/atomic"
)
type roundRobin struct {
index atomic.Uint32
}
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1)
srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}
}

View file

@ -0,0 +1,67 @@
package loadbalancer
import (
"context"
"net/http"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/net/types"
)
type (
Server struct {
Name string
URL types.URL
Weight weightType
handler http.Handler
pinger *http.Client
available atomic.Bool
}
servers []*Server
)
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
srv := &Server{
Name: name,
URL: url,
Weight: weightType(weight),
handler: handler,
pinger: &http.Client{Timeout: 3 * time.Second},
}
srv.available.Store(true)
return srv
}
func (srv *Server) checkUpdateAvail(ctx context.Context) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
srv.URL.String(),
nil,
)
if err != nil {
logger.Error("failed to create request: ", err)
srv.available.Store(false)
}
resp, err := srv.pinger.Do(req)
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
if !srv.available.Swap(true) {
logger.Infof("server %s is up", srv.Name)
}
} else if err != nil {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: %s", srv.Name, err)
}
} else {
if srv.available.Swap(false) {
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
}
}
}
func (srv *Server) String() string {
return srv.Name
}

View file

@ -5,7 +5,7 @@ import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)

View file

@ -13,7 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
)
const (

View file

@ -29,7 +29,6 @@ func init() {
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
@ -37,6 +36,10 @@ func init() {
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
// !experimental
"forwardauth": ForwardAuth.m,
"oauth2": OAuth2.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {

View file

@ -0,0 +1,129 @@
package middleware
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"reflect"
E "github.com/yusing/go-proxy/internal/error"
)
type oAuth2 struct {
*oAuth2Opts
m *Middleware
}
type oAuth2Opts struct {
ClientID string
ClientSecret string
AuthURL string // Authorization Endpoint
TokenURL string // Token Endpoint
}
var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2},
}
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
oauth := new(oAuth2)
oauth.m = &Middleware{
impl: oauth,
before: oauth.handleOAuth2,
}
oauth.oAuth2Opts = &oAuth2Opts{}
err := Deserialize(opts, oauth.oAuth2Opts)
if err != nil {
return nil, err
}
b := E.NewBuilder("missing required fields")
optV := reflect.ValueOf(oauth.oAuth2Opts)
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
if optV.FieldByName(field.Name).Len() == 0 {
b.Add(E.Missing(field.Name))
}
}
if b.HasError() {
return nil, b.Build().Subject("oAuth2")
}
return oauth.m, nil
}
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// Check if the user is authenticated (you may use session, cookie, etc.)
if !userIsAuthenticated(r) {
// TODO: Redirect to OAuth2 auth URL
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
return
}
// If you have a token in the query string, process it
if code := r.URL.Query().Get("code"); code != "" {
// Exchange the authorization code for a token here
// Use the TokenURL and authenticate the user
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
if err != nil {
// handle error
http.Error(rw, "failed to get token", http.StatusUnauthorized)
return
}
// Save token and user info based on your requirements
saveToken(rw, token)
// Redirect to the originally requested URL
http.Redirect(rw, r, "/", http.StatusFound)
return
}
// If user is authenticated, go to the next handler
next(rw, r)
}
func userIsAuthenticated(r *http.Request) bool {
// Example: Check for a session or cookie
session, err := r.Cookie("session_token")
if err != nil || session.Value == "" {
return false
}
// Validate the session_token if necessary
return true
}
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) {
// Prepare the request body
data := url.Values{
"client_id": {opts.ClientID},
"client_secret": {opts.ClientSecret},
"code": {code},
"grant_type": {"authorization_code"},
"redirect_uri": {requestUri},
}
resp, err := http.PostForm(opts.TokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to request token: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
}
// Decode the response
var tokenResp struct {
AccessToken string `json:"access_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %v", err)
}
return tokenResp.AccessToken, nil
}
func saveToken(rw ResponseWriter, token string) {
// Example: Save token in cookie
http.SetCookie(rw, &http.Cookie{
Name: "auth_token",
Value: token,
// set other properties as necessary, such as Secure and HttpOnly
})
}

View file

@ -4,7 +4,7 @@ import (
"net"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html

View file

@ -6,7 +6,7 @@ import (
"strings"
"testing"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -12,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
//go:embed test_data/sample_headers.json
@ -110,7 +111,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gphttp.NewReverseProxy(proxyURL, rr)
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr

View file

@ -26,6 +26,7 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/net/http/httpguts"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
)
@ -86,7 +87,7 @@ type ReverseProxy struct {
ServeHTTP http.HandlerFunc
TargetURL *url.URL
TargetURL types.URL
}
func singleJoiningSlash(a, b string) string {
@ -144,7 +145,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
// }
//
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
if transport == nil {
panic("nil transport")
}
@ -263,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
rewriteRequestURL(outreq, p.TargetURL)
rewriteRequestURL(outreq, p.TargetURL.URL)
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
@ -348,7 +349,6 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
roundTripMutex.Unlock()
if err != nil {
p.errorHandler(rw, outreq, err, false)
errMsg := err.Error()
res = &http.Response{
Status: http.StatusText(http.StatusBadGateway),
StatusCode: http.StatusBadGateway,
@ -358,7 +358,6 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
Request: outreq,
ContentLength: int64(len(errMsg)),
TLS: outreq.TLS,
}
}

24
internal/net/types/url.go Normal file
View file

@ -0,0 +1,24 @@
package types
import "net/url"
type URL struct{ *url.URL }
func NewURL(url *url.URL) URL {
return URL{url}
}
func (u URL) String() string {
if u.URL == nil {
return "nil"
}
return u.URL.String()
}
func (u URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u URL) Equals(other URL) bool {
return u.URL == other.URL || u.String() == other.String()
}

View file

@ -7,6 +7,8 @@ import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
)
@ -15,9 +17,10 @@ type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias
Scheme T.Scheme
URL *url.URL
URL net.URL
NoTLSVerify bool
PathPatterns T.PathPatterns
LoadBalance loadbalancer.Config
Middlewares D.NestedLabelMap
/* Docker only */
@ -47,6 +50,10 @@ func (rp *ReverseProxyEntry) IsDocker() bool {
return rp.DockerHost != ""
}
func (rp *ReverseProxyEntry) IsZeroPort() bool {
return rp.URL.Port() == "0"
}
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
m.FillMissingFields()
@ -107,9 +114,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
return &ReverseProxyEntry{
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: url,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
LoadBalance: m.LoadBalance,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,

View file

@ -5,7 +5,6 @@ import (
"sync"
"net/http"
"net/url"
"strings"
"github.com/sirupsen/logrus"
@ -13,6 +12,7 @@ import (
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
@ -21,16 +21,14 @@ import (
type (
HTTPRoute struct {
Alias PT.Alias `json:"alias"`
TargetURL *URL `json:"target_url"`
PathPatterns PT.PathPatterns `json:"path_patterns"`
*P.ReverseProxyEntry
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"`
entry *P.ReverseProxyEntry
server *loadbalancer.Server
handler http.Handler
rp *ReverseProxy
}
URL url.URL
SubdomainKey = PT.Alias
ReverseProxyHandler struct {
@ -80,10 +78,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
defer httpRoutesMu.Unlock()
r := &HTTPRoute{
Alias: entry.Alias,
TargetURL: (*URL)(entry.URL),
PathPatterns: entry.PathPatterns,
entry: entry,
ReverseProxyEntry: entry,
rp: rp,
}
return r, nil
@ -101,18 +96,19 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.entry.UseIdleWatcher() {
watcher, err := idlewatcher.Register(r.entry)
switch {
case r.UseIdleWatcher():
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
if err != nil {
return err
}
r.handler = idlewatcher.NewWaker(watcher, r.rp)
} else if r.entry.URL.Port() == "0" ||
r.entry.IsDocker() && !r.entry.ContainerRunning {
case r.IsZeroPort() ||
r.IsDocker() && !r.ContainerRunning:
return nil
} else if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" {
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
r.handler = ReverseProxyHandler{r.rp}
} else {
default:
mux := http.NewServeMux()
for _, p := range r.PathPatterns {
mux.HandleFunc(string(p), r.rp.ServeHTTP)
@ -120,8 +116,26 @@ func (r *HTTPRoute) Start() E.NestedError {
r.handler = mux
}
if r.LoadBalance.Link == "" {
httpRoutes.Store(string(r.Alias), r)
return nil
}
var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
if ok {
lb = linked.LoadBalancer
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
linked = &HTTPRoute{
LoadBalancer: lb,
handler: lb,
}
httpRoutes.Store(string(r.LoadBalance.Link), linked)
}
lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler))
return nil
}
func (r *HTTPRoute) Stop() (_ E.NestedError) {
@ -137,7 +151,18 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
}
r.handler = nil
if r.server != nil {
linked, ok := httpRoutes.Load(string(r.LoadBalance.Link))
if ok {
linked.LoadBalancer.RemoveServer(r.server)
}
if linked.LoadBalancer.IsEmpty() {
httpRoutes.Delete(string(r.LoadBalance.Link))
}
} else {
httpRoutes.Delete(string(r.Alias))
}
return
}
@ -145,14 +170,6 @@ func (r *HTTPRoute) Started() bool {
return r.handler != nil
}
func (u *URL) String() string {
return (*url.URL)(u).String()
}
func (u *URL) MarshalText() (text []byte, err error) {
return []byte(u.String()), nil
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
mux, err := findMuxFunc(r.Host)
if err != nil {

View file

@ -8,6 +8,7 @@ import (
. "github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
H "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@ -21,8 +22,9 @@ type (
Port string `yaml:"port" json:"port"`
NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only
PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only
LoadBalance loadbalancer.Config `yaml:"load_balance" json:"load_balance"`
Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"`
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage"`
Homepage *H.HomePageItem `yaml:"homepage" json:"homepage,omitempty"`
/* Docker only */
*D.ProxyProperties `yaml:"-" json:"proxy_properties"`