diff --git a/internal/config/config.go b/internal/config/config.go index 5682311..370ad25 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -212,7 +212,7 @@ func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedE continue } cfg.proxyProviders.Store(p.GetName(), p) - b.Add(p.LoadRoutes().Subject(dockerHost)) + b.Add(p.LoadRoutes().Subject(p.GetName())) } return } diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 90d23dc..9e00174 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -2,7 +2,6 @@ package idlewatcher import ( "context" - "crypto/tls" "fmt" "net/http" "strconv" @@ -19,10 +18,6 @@ type Waker struct { } func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker { - tr := &http.Transport{} - if w.NoTLSVerify { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } orig := rp.ServeHTTP // workaround for stopped containers port become zero rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) { @@ -41,7 +36,7 @@ func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker { watcher: w, client: &http.Client{ Timeout: 1 * time.Second, - Transport: tr, + Transport: rp.Transport, }, rp: rp, } diff --git a/internal/error/builder.go b/internal/error/builder.go index 2950923..4528de7 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -65,6 +65,10 @@ func (b Builder) To(ptr *NestedError) { } } +func (b Builder) String() string { + return b.Build().String() +} + func (b Builder) HasError() bool { return len(b.errors) > 0 } diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go new file mode 100644 index 0000000..8223516 --- /dev/null +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -0,0 +1,33 @@ +package loadbalancer + +import ( + "hash/fnv" + "net" + "net/http" +) + +type ipHash struct{ *LoadBalancer } + +func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} } +func (ipHash) OnAddServer(srv *Server) {} +func (ipHash) OnRemoveServer(srv *Server) {} + +func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(rw, "Internal error", http.StatusInternalServerError) + logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err) + return + } + idx := hashIP(ip) % uint32(len(impl.pool)) + if !impl.pool[idx].available.Load() { + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + } + impl.pool[idx].handler.ServeHTTP(rw, r) +} + +func hashIP(ip string) uint32 { + h := fnv.New32a() + h.Write([]byte(ip)) + return h.Sum32() +} diff --git a/internal/net/http/loadbalancer/least_conn.go b/internal/net/http/loadbalancer/least_conn.go new file mode 100644 index 0000000..3c1b872 --- /dev/null +++ b/internal/net/http/loadbalancer/least_conn.go @@ -0,0 +1,53 @@ +package loadbalancer + +import ( + "net/http" + "sync/atomic" + + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +type leastConn struct { + *LoadBalancer + nConn F.Map[*Server, *atomic.Int64] +} + +func (lb *LoadBalancer) newLeastConn() impl { + return &leastConn{ + LoadBalancer: lb, + nConn: F.NewMapOf[*Server, *atomic.Int64](), + } +} + +func (impl *leastConn) OnAddServer(srv *Server) { + impl.nConn.Store(srv, new(atomic.Int64)) +} + +func (impl *leastConn) OnRemoveServer(srv *Server) { + impl.nConn.Delete(srv) +} + +func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) { + srv := srvs[0] + minConn, ok := impl.nConn.Load(srv) + if !ok { + logger.Errorf("[BUG] server %s not found", srv.Name) + http.Error(rw, "Internal error", http.StatusInternalServerError) + } + + for i := 1; i < len(srvs); i++ { + nConn, ok := impl.nConn.Load(srvs[i]) + if !ok { + logger.Errorf("[BUG] server %s not found", srv.Name) + http.Error(rw, "Internal error", http.StatusInternalServerError) + } + if nConn.Load() < minConn.Load() { + minConn = nConn + srv = srvs[i] + } + } + + minConn.Add(1) + srv.handler.ServeHTTP(rw, r) + minConn.Add(-1) +} diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go new file mode 100644 index 0000000..4de47db --- /dev/null +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -0,0 +1,241 @@ +package loadbalancer + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/go-acme/lego/v4/log" + E "github.com/yusing/go-proxy/internal/error" +) + +// TODO: stats of each server +// TODO: support weighted mode +type ( + impl interface { + ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) + OnAddServer(srv *Server) + OnRemoveServer(srv *Server) + } + Config struct { + Link string + Mode Mode + Weight weightType + } + LoadBalancer struct { + impl + Config + + pool servers + poolMu sync.RWMutex + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + sumWeight weightType + } + + weightType uint16 +) + +const maxWeight weightType = 100 + +func New(cfg Config) *LoadBalancer { + lb := &LoadBalancer{Config: cfg, pool: servers{}} + mode := cfg.Mode + if !cfg.Mode.ValidateUpdate() { + logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode) + } + switch mode { + case RoundRobin: + lb.impl = lb.newRoundRobin() + case LeastConn: + lb.impl = lb.newLeastConn() + case IPHash: + lb.impl = lb.newIPHash() + default: // should happen in test only + lb.impl = lb.newRoundRobin() + } + return lb +} + +func (lb *LoadBalancer) AddServer(srv *Server) { + lb.poolMu.Lock() + defer lb.poolMu.Unlock() + + lb.pool = append(lb.pool, srv) + lb.sumWeight += srv.Weight + + lb.impl.OnAddServer(srv) +} + +func (lb *LoadBalancer) RemoveServer(srv *Server) { + lb.poolMu.RLock() + defer lb.poolMu.RUnlock() + + lb.impl.OnRemoveServer(srv) + + for i, s := range lb.pool { + if s == srv { + lb.pool = append(lb.pool[:i], lb.pool[i+1:]...) + break + } + } + if lb.IsEmpty() { + lb.Stop() + } +} + +func (lb *LoadBalancer) IsEmpty() bool { + return len(lb.pool) == 0 +} + +func (lb *LoadBalancer) Rebalance() { + if lb.sumWeight == maxWeight { + return + } + if lb.sumWeight == 0 { // distribute evenly + weightEach := maxWeight / weightType(len(lb.pool)) + remainer := maxWeight % weightType(len(lb.pool)) + for _, s := range lb.pool { + s.Weight = weightEach + lb.sumWeight += weightEach + if remainer > 0 { + s.Weight++ + } + remainer-- + } + return + } + + // scale evenly + scaleFactor := float64(maxWeight) / float64(lb.sumWeight) + lb.sumWeight = 0 + + for _, s := range lb.pool { + s.Weight = weightType(float64(s.Weight) * scaleFactor) + lb.sumWeight += s.Weight + } + + delta := maxWeight - lb.sumWeight + if delta == 0 { + return + } + for _, s := range lb.pool { + if delta == 0 { + break + } + if delta > 0 { + s.Weight++ + lb.sumWeight++ + delta-- + } else { + s.Weight-- + lb.sumWeight-- + delta++ + } + } +} + +func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + srvs := lb.availServers() + if len(srvs) == 0 { + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) + return + } + lb.impl.ServeHTTP(srvs, rw, r) +} + +func (lb *LoadBalancer) Start() { + if lb.IsEmpty() { + return + } + + if lb.sumWeight != 0 && lb.sumWeight != maxWeight { + msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight) + for _, s := range lb.pool { + msg.Addf("%s: %d", s.Name, s.Weight) + } + lb.Rebalance() + inner := E.NewBuilder("After rebalancing") + for _, s := range lb.pool { + inner.Addf("%s: %d", s.Name, s.Weight) + } + msg.Addf("%s", inner) + logger.Warn(msg) + } + + if lb.sumWeight != 0 { + log.Warnf("Weighted mode not supported yet") + } + + switch lb.Mode { + case RoundRobin: + lb.impl = lb.newRoundRobin() + case LeastConn: + lb.impl = lb.newLeastConn() + case IPHash: + lb.impl = lb.newIPHash() + } + + lb.done = make(chan struct{}, 1) + lb.ctx, lb.cancel = context.WithCancel(context.Background()) + + updateAll := func() { + var wg sync.WaitGroup + wg.Add(len(lb.pool)) + for _, s := range lb.pool { + go func(s *Server) { + defer wg.Done() + s.checkUpdateAvail(lb.ctx) + }(s) + } + wg.Wait() + } + + go func() { + defer lb.cancel() + defer close(lb.done) + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + updateAll() + for { + select { + case <-lb.ctx.Done(): + return + case <-ticker.C: + lb.poolMu.RLock() + updateAll() + lb.poolMu.RUnlock() + } + } + }() +} + +func (lb *LoadBalancer) Stop() { + if lb.impl == nil { + return + } + + lb.cancel() + + <-lb.done + lb.pool = nil +} + +func (lb *LoadBalancer) availServers() servers { + lb.poolMu.Lock() + defer lb.poolMu.Unlock() + + avail := servers{} + for _, s := range lb.pool { + if s.available.Load() { + avail = append(avail, s) + } + } + return avail +} diff --git a/internal/net/http/loadbalancer/loadbalancer_test.go b/internal/net/http/loadbalancer/loadbalancer_test.go new file mode 100644 index 0000000..4b5f9ec --- /dev/null +++ b/internal/net/http/loadbalancer/loadbalancer_test.go @@ -0,0 +1,43 @@ +package loadbalancer + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestRebalance(t *testing.T) { + t.Parallel() + t.Run("zero", func(t *testing.T) { + lb := New(Config{}) + for range 10 { + lb.AddServer(&Server{}) + } + lb.Rebalance() + ExpectEqual(t, lb.sumWeight, maxWeight) + }) + t.Run("less", func(t *testing.T) { + lb := New(Config{}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) + lb.Rebalance() + // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) + ExpectEqual(t, lb.sumWeight, maxWeight) + }) + t.Run("more", func(t *testing.T) { + lb := New(Config{}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) + lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) + lb.Rebalance() + // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) + ExpectEqual(t, lb.sumWeight, maxWeight) + }) +} diff --git a/internal/net/http/loadbalancer/logger.go b/internal/net/http/loadbalancer/logger.go new file mode 100644 index 0000000..7b9b51d --- /dev/null +++ b/internal/net/http/loadbalancer/logger.go @@ -0,0 +1,5 @@ +package loadbalancer + +import "github.com/sirupsen/logrus" + +var logger = logrus.WithField("module", "load_balancer") diff --git a/internal/net/http/loadbalancer/mode.go b/internal/net/http/loadbalancer/mode.go new file mode 100644 index 0000000..9d6f91d --- /dev/null +++ b/internal/net/http/loadbalancer/mode.go @@ -0,0 +1,29 @@ +package loadbalancer + +import ( + U "github.com/yusing/go-proxy/internal/utils" +) + +type Mode string + +const ( + RoundRobin Mode = "roundrobin" + LeastConn Mode = "leastconn" + IPHash Mode = "iphash" +) + +func (mode *Mode) ValidateUpdate() bool { + switch U.ToLowerNoSnake(string(*mode)) { + case "", string(RoundRobin): + *mode = RoundRobin + return true + case string(LeastConn): + *mode = LeastConn + return true + case string(IPHash): + *mode = IPHash + return true + } + *mode = RoundRobin + return false +} diff --git a/internal/net/http/loadbalancer/round_robin.go b/internal/net/http/loadbalancer/round_robin.go new file mode 100644 index 0000000..3db994a --- /dev/null +++ b/internal/net/http/loadbalancer/round_robin.go @@ -0,0 +1,22 @@ +package loadbalancer + +import ( + "net/http" + "sync/atomic" +) + +type roundRobin struct { + index atomic.Uint32 +} + +func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} } +func (lb *roundRobin) OnAddServer(srv *Server) {} +func (lb *roundRobin) OnRemoveServer(srv *Server) {} + +func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) { + index := lb.index.Add(1) + srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r) + if lb.index.Load() >= 2*uint32(len(srvs)) { + lb.index.Store(0) + } +} diff --git a/internal/net/http/loadbalancer/server.go b/internal/net/http/loadbalancer/server.go new file mode 100644 index 0000000..798693b --- /dev/null +++ b/internal/net/http/loadbalancer/server.go @@ -0,0 +1,67 @@ +package loadbalancer + +import ( + "context" + "net/http" + "sync/atomic" + "time" + + "github.com/yusing/go-proxy/internal/net/types" +) + +type ( + Server struct { + Name string + URL types.URL + Weight weightType + handler http.Handler + + pinger *http.Client + available atomic.Bool + } + servers []*Server +) + +func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server { + srv := &Server{ + Name: name, + URL: url, + Weight: weightType(weight), + handler: handler, + pinger: &http.Client{Timeout: 3 * time.Second}, + } + srv.available.Store(true) + return srv +} + +func (srv *Server) checkUpdateAvail(ctx context.Context) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodHead, + srv.URL.String(), + nil, + ) + if err != nil { + logger.Error("failed to create request: ", err) + srv.available.Store(false) + } + + resp, err := srv.pinger.Do(req) + if err == nil && resp.StatusCode != http.StatusServiceUnavailable { + if !srv.available.Swap(true) { + logger.Infof("server %s is up", srv.Name) + } + } else if err != nil { + if srv.available.Swap(false) { + logger.Warnf("server %s is down: %s", srv.Name, err) + } + } else { + if srv.available.Swap(false) { + logger.Warnf("server %s is down: status %s", srv.Name, resp.Status) + } + } +} + +func (srv *Server) String() string { + return srv.Name +} diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index 3a5cfe5..502d6f1 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -5,7 +5,7 @@ import ( "net/http" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/net/types" F "github.com/yusing/go-proxy/internal/utils/functional" ) diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index cd7f64c..7de7326 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -13,7 +13,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/net/types" ) const ( diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 60a1a84..26e67a7 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -29,7 +29,6 @@ func init() { "setxforwarded": SetXForwarded, "hidexforwarded": HideXForwarded, "redirecthttp": RedirectHTTP, - "forwardauth": ForwardAuth.m, "modifyresponse": ModifyResponse.m, "modifyrequest": ModifyRequest.m, "errorpage": CustomErrorPage, @@ -37,6 +36,10 @@ func init() { "realip": RealIP.m, "cloudflarerealip": CloudflareRealIP.m, "cidrwhitelist": CIDRWhiteList.m, + + // !experimental + "forwardauth": ForwardAuth.m, + "oauth2": OAuth2.m, } names := make(map[*Middleware][]string) for name, m := range middlewares { diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go new file mode 100644 index 0000000..8b53804 --- /dev/null +++ b/internal/net/http/middleware/oauth2.go @@ -0,0 +1,129 @@ +package middleware + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "reflect" + + E "github.com/yusing/go-proxy/internal/error" +) + +type oAuth2 struct { + *oAuth2Opts + m *Middleware +} + +type oAuth2Opts struct { + ClientID string + ClientSecret string + AuthURL string // Authorization Endpoint + TokenURL string // Token Endpoint +} + +var OAuth2 = &oAuth2{ + m: &Middleware{withOptions: NewAuthentikOAuth2}, +} + +func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) { + oauth := new(oAuth2) + oauth.m = &Middleware{ + impl: oauth, + before: oauth.handleOAuth2, + } + oauth.oAuth2Opts = &oAuth2Opts{} + err := Deserialize(opts, oauth.oAuth2Opts) + if err != nil { + return nil, err + } + b := E.NewBuilder("missing required fields") + optV := reflect.ValueOf(oauth.oAuth2Opts) + for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) { + if optV.FieldByName(field.Name).Len() == 0 { + b.Add(E.Missing(field.Name)) + } + } + if b.HasError() { + return nil, b.Build().Subject("oAuth2") + } + return oauth.m, nil +} + +func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { + // Check if the user is authenticated (you may use session, cookie, etc.) + if !userIsAuthenticated(r) { + // TODO: Redirect to OAuth2 auth URL + http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", + oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) + return + } + + // If you have a token in the query string, process it + if code := r.URL.Query().Get("code"); code != "" { + // Exchange the authorization code for a token here + // Use the TokenURL and authenticate the user + token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI) + if err != nil { + // handle error + http.Error(rw, "failed to get token", http.StatusUnauthorized) + return + } + + // Save token and user info based on your requirements + saveToken(rw, token) + + // Redirect to the originally requested URL + http.Redirect(rw, r, "/", http.StatusFound) + return + } + + // If user is authenticated, go to the next handler + next(rw, r) +} + +func userIsAuthenticated(r *http.Request) bool { + // Example: Check for a session or cookie + session, err := r.Cookie("session_token") + if err != nil || session.Value == "" { + return false + } + // Validate the session_token if necessary + return true +} + +func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) { + // Prepare the request body + data := url.Values{ + "client_id": {opts.ClientID}, + "client_secret": {opts.ClientSecret}, + "code": {code}, + "grant_type": {"authorization_code"}, + "redirect_uri": {requestUri}, + } + resp, err := http.PostForm(opts.TokenURL, data) + if err != nil { + return "", fmt.Errorf("failed to request token: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) + } + // Decode the response + var tokenResp struct { + AccessToken string `json:"access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token response: %v", err) + } + return tokenResp.AccessToken, nil +} + +func saveToken(rw ResponseWriter, token string) { + // Example: Save token in cookie + http.SetCookie(rw, &http.Cookie{ + Name: "auth_token", + Value: token, + // set other properties as necessary, such as Secure and HttpOnly + }) +} diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index ff5691f..f3ebc25 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -4,7 +4,7 @@ import ( "net" E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/net/types" ) // https://nginx.org/en/docs/http/ngx_http_realip_module.html diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index fa98a1d..f47c272 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index 18f90a1..47707c8 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/types" ) //go:embed test_data/sample_headers.json @@ -110,7 +111,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N } else { proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect } - rp := gphttp.NewReverseProxy(proxyURL, rr) + rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr) mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index cca5837..4371990 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -26,6 +26,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/net/http/httpguts" + "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" ) @@ -86,7 +87,7 @@ type ReverseProxy struct { ServeHTTP http.HandlerFunc - TargetURL *url.URL + TargetURL types.URL } func singleJoiningSlash(a, b string) string { @@ -144,7 +145,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // } // -func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy { +func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy { if transport == nil { panic("nil transport") } @@ -263,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate } - rewriteRequestURL(outreq, p.TargetURL) + rewriteRequestURL(outreq, p.TargetURL.URL) outreq.Close = false reqUpType := UpgradeType(outreq.Header) @@ -348,18 +349,16 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { roundTripMutex.Unlock() if err != nil { p.errorHandler(rw, outreq, err, false) - errMsg := err.Error() res = &http.Response{ - Status: http.StatusText(http.StatusBadGateway), - StatusCode: http.StatusBadGateway, - Proto: outreq.Proto, - ProtoMajor: outreq.ProtoMajor, - ProtoMinor: outreq.ProtoMinor, - Header: make(http.Header), - Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))), - Request: outreq, - ContentLength: int64(len(errMsg)), - TLS: outreq.TLS, + Status: http.StatusText(http.StatusBadGateway), + StatusCode: http.StatusBadGateway, + Proto: outreq.Proto, + ProtoMajor: outreq.ProtoMajor, + ProtoMinor: outreq.ProtoMinor, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))), + Request: outreq, + TLS: outreq.TLS, } } diff --git a/internal/types/cidr.go b/internal/net/types/cidr.go similarity index 100% rename from internal/types/cidr.go rename to internal/net/types/cidr.go diff --git a/internal/net/types/url.go b/internal/net/types/url.go new file mode 100644 index 0000000..065b0ba --- /dev/null +++ b/internal/net/types/url.go @@ -0,0 +1,24 @@ +package types + +import "net/url" + +type URL struct{ *url.URL } + +func NewURL(url *url.URL) URL { + return URL{url} +} + +func (u URL) String() string { + if u.URL == nil { + return "nil" + } + return u.URL.String() +} + +func (u URL) MarshalText() (text []byte, err error) { + return []byte(u.String()), nil +} + +func (u URL) Equals(other URL) bool { + return u.URL == other.URL || u.String() == other.String() +} diff --git a/internal/proxy/entry.go b/internal/proxy/entry.go index d2b4a0b..845d09a 100644 --- a/internal/proxy/entry.go +++ b/internal/proxy/entry.go @@ -7,6 +7,8 @@ import ( D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" + net "github.com/yusing/go-proxy/internal/net/types" T "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/types" ) @@ -15,9 +17,10 @@ type ( ReverseProxyEntry struct { // real model after validation Alias T.Alias Scheme T.Scheme - URL *url.URL + URL net.URL NoTLSVerify bool PathPatterns T.PathPatterns + LoadBalance loadbalancer.Config Middlewares D.NestedLabelMap /* Docker only */ @@ -47,6 +50,10 @@ func (rp *ReverseProxyEntry) IsDocker() bool { return rp.DockerHost != "" } +func (rp *ReverseProxyEntry) IsZeroPort() bool { + return rp.URL.Port() == "0" +} + func ValidateEntry(m *types.RawEntry) (any, E.NestedError) { m.FillMissingFields() @@ -107,9 +114,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn return &ReverseProxyEntry{ Alias: T.NewAlias(m.Alias), Scheme: s, - URL: url, + URL: net.NewURL(url), NoTLSVerify: m.NoTLSVerify, PathPatterns: pathPatterns, + LoadBalance: m.LoadBalance, Middlewares: m.Middlewares, IdleTimeout: idleTimeout, WakeTimeout: wakeTimeout, diff --git a/internal/route/http.go b/internal/route/http.go index 13fa463..61d0378 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -5,7 +5,6 @@ import ( "sync" "net/http" - "net/url" "strings" "github.com/sirupsen/logrus" @@ -13,6 +12,7 @@ import ( "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" @@ -21,16 +21,14 @@ import ( type ( HTTPRoute struct { - Alias PT.Alias `json:"alias"` - TargetURL *URL `json:"target_url"` - PathPatterns PT.PathPatterns `json:"path_patterns"` + *P.ReverseProxyEntry + LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"` - entry *P.ReverseProxyEntry + server *loadbalancer.Server handler http.Handler rp *ReverseProxy } - URL url.URL SubdomainKey = PT.Alias ReverseProxyHandler struct { @@ -80,11 +78,8 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { defer httpRoutesMu.Unlock() r := &HTTPRoute{ - Alias: entry.Alias, - TargetURL: (*URL)(entry.URL), - PathPatterns: entry.PathPatterns, - entry: entry, - rp: rp, + ReverseProxyEntry: entry, + rp: rp, } return r, nil } @@ -101,18 +96,19 @@ func (r *HTTPRoute) Start() E.NestedError { httpRoutesMu.Lock() defer httpRoutesMu.Unlock() - if r.entry.UseIdleWatcher() { - watcher, err := idlewatcher.Register(r.entry) + switch { + case r.UseIdleWatcher(): + watcher, err := idlewatcher.Register(r.ReverseProxyEntry) if err != nil { return err } r.handler = idlewatcher.NewWaker(watcher, r.rp) - } else if r.entry.URL.Port() == "0" || - r.entry.IsDocker() && !r.entry.ContainerRunning { + case r.IsZeroPort() || + r.IsDocker() && !r.ContainerRunning: return nil - } else if len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/" { + case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/": r.handler = ReverseProxyHandler{r.rp} - } else { + default: mux := http.NewServeMux() for _, p := range r.PathPatterns { mux.HandleFunc(string(p), r.rp.ServeHTTP) @@ -120,7 +116,25 @@ func (r *HTTPRoute) Start() E.NestedError { r.handler = mux } - httpRoutes.Store(string(r.Alias), r) + if r.LoadBalance.Link == "" { + httpRoutes.Store(string(r.Alias), r) + return nil + } + + var lb *loadbalancer.LoadBalancer + linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) + if ok { + lb = linked.LoadBalancer + } else { + lb = loadbalancer.New(r.LoadBalance) + lb.Start() + linked = &HTTPRoute{ + LoadBalancer: lb, + handler: lb, + } + httpRoutes.Store(string(r.LoadBalance.Link), linked) + } + lb.AddServer(loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler)) return nil } @@ -137,7 +151,18 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { } r.handler = nil - httpRoutes.Delete(string(r.Alias)) + + if r.server != nil { + linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) + if ok { + linked.LoadBalancer.RemoveServer(r.server) + } + if linked.LoadBalancer.IsEmpty() { + httpRoutes.Delete(string(r.LoadBalance.Link)) + } + } else { + httpRoutes.Delete(string(r.Alias)) + } return } @@ -145,14 +170,6 @@ func (r *HTTPRoute) Started() bool { return r.handler != nil } -func (u *URL) String() string { - return (*url.URL)(u).String() -} - -func (u *URL) MarshalText() (text []byte, err error) { - return []byte(u.String()), nil -} - func ProxyHandler(w http.ResponseWriter, r *http.Request) { mux, err := findMuxFunc(r.Host) if err != nil { diff --git a/internal/types/raw_entry.go b/internal/types/raw_entry.go index 44c9e2d..1f4cf6b 100644 --- a/internal/types/raw_entry.go +++ b/internal/types/raw_entry.go @@ -8,6 +8,7 @@ import ( . "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" H "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/net/http/loadbalancer" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -15,14 +16,15 @@ type ( RawEntry struct { // raw entry object before validation // loaded from docker labels or yaml file - Alias string `yaml:"-" json:"-"` - Scheme string `yaml:"scheme" json:"scheme"` - Host string `yaml:"host" json:"host"` - Port string `yaml:"port" json:"port"` - NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only - PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only - Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"` - Homepage *H.HomePageItem `yaml:"homepage" json:"homepage"` + Alias string `yaml:"-" json:"-"` + Scheme string `yaml:"scheme" json:"scheme"` + Host string `yaml:"host" json:"host"` + Port string `yaml:"port" json:"port"` + NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only + PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only + LoadBalance loadbalancer.Config `yaml:"load_balance" json:"load_balance"` + Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"` + Homepage *H.HomePageItem `yaml:"homepage" json:"homepage,omitempty"` /* Docker only */ *D.ProxyProperties `yaml:"-" json:"proxy_properties"`