Fixed a few issues:

- Incorrect name being shown on dashboard "Proxies page"
- Apps being shown when homepage.show is false
- Load balanced routes are shown on homepage instead of the load balancer
- Route with idlewatcher will now be removed on container destroy
- Idlewatcher panic
- Performance improvement
- Idlewatcher infinitely loading
- Reload stuck / not working properly
- Streams stuck on shutdown / reload
- etc...
Added:
- support idlewatcher for load-balanced routes
- partial implementation for stream type idlewatcher
Issues:
- graceful shutdown
This commit is contained in:
yusing 2024-10-18 16:47:01 +08:00
parent c0c61709ca
commit 53557e38b6
69 changed files with 2368 additions and 1654 deletions

View file

@ -30,6 +30,12 @@ get:
debug: debug:
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
debug-trace:
make build && sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
profile:
GODEBUG=gctrace=1 make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
mtrace: mtrace:
bin/go-proxy debug-ls-mtrace > mtrace.json bin/go-proxy debug-ls-mtrace > mtrace.json

View file

@ -20,6 +20,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
) )
@ -32,8 +33,14 @@ func main() {
} }
l := logrus.WithField("module", "main") l := logrus.WithField("module", "main")
timeFmt := "01-02 15:04:05"
fullTS := true
if common.IsDebug { if common.IsTrace {
logrus.SetLevel(logrus.TraceLevel)
timeFmt = "04:05"
fullTS = false
} else if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel) logrus.SetLevel(logrus.DebugLevel)
} }
@ -42,9 +49,9 @@ func main() {
} else { } else {
logrus.SetFormatter(&logrus.TextFormatter{ logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true, DisableSorting: true,
FullTimestamp: true, FullTimestamp: fullTS,
ForceColors: true, ForceColors: true,
TimestampFormat: "01-02 15:04:05", TimestampFormat: timeFmt,
}) })
logrus.Infof("go-proxy version %s", pkg.GetVersion()) logrus.Infof("go-proxy version %s", pkg.GetVersion())
} }
@ -76,21 +83,22 @@ func main() {
middleware.LoadComposeFiles() middleware.LoadComposeFiles()
if err := config.Load(); err != nil { var cfg *config.Config
var err E.NestedError
if cfg, err = config.Load(); err != nil {
logrus.Warn(err) logrus.Warn(err)
} }
cfg := config.GetInstance()
switch args.Command { switch args.Command {
case common.CommandListConfigs: case common.CommandListConfigs:
printJSON(cfg.Value()) printJSON(config.Value())
return return
case common.CommandListRoutes: case common.CommandListRoutes:
routes, err := query.ListRoutes() routes, err := query.ListRoutes()
if err != nil { if err != nil {
log.Printf("failed to connect to api server: %s", err) log.Printf("failed to connect to api server: %s", err)
log.Printf("falling back to config file") log.Printf("falling back to config file")
printJSON(cfg.RoutesByAlias()) printJSON(config.RoutesByAlias())
} else { } else {
printJSON(routes) printJSON(routes)
} }
@ -103,10 +111,10 @@ func main() {
printJSON(icons) printJSON(icons)
return return
case common.CommandDebugListEntries: case common.CommandDebugListEntries:
printJSON(cfg.DumpEntries()) printJSON(config.DumpEntries())
return return
case common.CommandDebugListProviders: case common.CommandDebugListProviders:
printJSON(cfg.DumpProviders()) printJSON(config.DumpProviders())
return return
case common.CommandDebugListMTrace: case common.CommandDebugListMTrace:
trace, err := query.ListMiddlewareTraces() trace, err := query.ListMiddlewareTraces()
@ -114,17 +122,25 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
printJSON(trace) printJSON(trace)
return
case common.CommandDebugListTasks:
tasks, err := query.DebugListTasks()
if err != nil {
log.Fatal(err)
}
printJSON(tasks)
return
} }
cfg.StartProxyProviders() cfg.StartProxyProviders()
cfg.WatchChanges() config.WatchChanges()
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGINT)
signal.Notify(sig, syscall.SIGTERM) signal.Notify(sig, syscall.SIGTERM)
signal.Notify(sig, syscall.SIGHUP) signal.Notify(sig, syscall.SIGHUP)
autocert := cfg.GetAutoCertProvider() autocert := config.GetAutoCertProvider()
if autocert != nil { if autocert != nil {
if err := autocert.Setup(); err != nil { if err := autocert.Setup(); err != nil {
l.Fatal(err) l.Fatal(err)
@ -139,14 +155,14 @@ func main() {
HTTPAddr: common.ProxyHTTPAddr, HTTPAddr: common.ProxyHTTPAddr,
HTTPSAddr: common.ProxyHTTPSAddr, HTTPSAddr: common.ProxyHTTPSAddr,
Handler: http.HandlerFunc(R.ProxyHandler), Handler: http.HandlerFunc(R.ProxyHandler),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS, RedirectToHTTPS: config.Value().RedirectToHTTPS,
}) })
apiServer := server.InitAPIServer(server.Options{ apiServer := server.InitAPIServer(server.Options{
Name: "api", Name: "api",
CertProvider: autocert, CertProvider: autocert,
HTTPAddr: common.APIHTTPAddr, HTTPAddr: common.APIHTTPAddr,
Handler: api.NewHandler(cfg), Handler: api.NewHandler(),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS, RedirectToHTTPS: config.Value().RedirectToHTTPS,
}) })
proxyServer.Start() proxyServer.Start()
@ -157,8 +173,8 @@ func main() {
// grafully shutdown // grafully shutdown
logrus.Info("shutting down") logrus.Info("shutting down")
common.CancelGlobalContext() task.CancelGlobalContext()
common.GlobalContextWait(time.Second * time.Duration(cfg.Value().TimeoutShutdown)) task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
} }
func prepareDirectory(dir string) { func prepareDirectory(dir string) {

View file

@ -2,6 +2,7 @@ package api
import ( import (
"fmt" "fmt"
"net"
"net/http" "net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1" v1 "github.com/yusing/go-proxy/internal/api/v1"
@ -21,34 +22,35 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler)) mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
} }
func NewHandler(cfg *config.Config) http.Handler { func NewHandler() http.Handler {
mux := NewServeMux() mux := NewServeMux()
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth)) mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth)
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth)) mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth)
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload)) mux.HandleFunc("POST", "/v1/reload", v1.Reload)
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List)) mux.HandleFunc("GET", "/v1/list", v1.List)
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List)) mux.HandleFunc("GET", "/v1/list/{what}", v1.List)
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent) mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent) mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent)
mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS)) mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc()) mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux return mux
} }
// allow only requests to API server with host matching common.APIHTTPAddr. // allow only requests to API server with localhost.
func checkHost(f http.HandlerFunc) http.HandlerFunc { func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug { if common.IsDebug {
return f return f
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Host != common.APIHTTPAddr { host, _, _ := net.SplitHostPort(r.RemoteAddr)
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr) if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
http.Error(w, "invalid request", http.StatusForbidden) Logger.Warnf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return return
} }
f(w, r) f(w, r)

View file

@ -4,11 +4,10 @@ import (
"net/http" "net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils" . "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func CheckHealth(w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target") target := r.FormValue("target")
if target == "" { if target == "" {
HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest) HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest)

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/provider" "github.com/yusing/go-proxy/internal/route/provider"
) )
func GetFileContent(w http.ResponseWriter, r *http.Request) { func GetFileContent(w http.ResponseWriter, r *http.Request) {

View file

@ -9,19 +9,21 @@ import (
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
const ( const (
ListRoutes = "routes" ListRoutes = "routes"
ListConfigFiles = "config_files" ListConfigFiles = "config_files"
ListMiddlewares = "middlewares" ListMiddlewares = "middlewares"
ListMiddlewareTrace = "middleware_trace" ListMiddlewareTraces = "middleware_trace"
ListMatchDomains = "match_domains" ListMatchDomains = "match_domains"
ListHomepageConfig = "homepage_config" ListHomepageConfig = "homepage_config"
ListTasks = "tasks"
) )
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func List(w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what") what := r.PathValue("what")
if what == "" { if what == "" {
what = ListRoutes what = ListRoutes
@ -29,27 +31,24 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
switch what { switch what {
case ListRoutes: case ListRoutes:
listRoutes(cfg, w, r) U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type"))))
case ListConfigFiles: case ListConfigFiles:
listConfigFiles(w, r) listConfigFiles(w, r)
case ListMiddlewares: case ListMiddlewares:
listMiddlewares(w, r) U.RespondJSON(w, r, middleware.All())
case ListMiddlewareTrace: case ListMiddlewareTraces:
listMiddlewareTrace(w, r) U.RespondJSON(w, r, middleware.GetAllTrace())
case ListMatchDomains: case ListMatchDomains:
listMatchDomains(cfg, w, r) U.RespondJSON(w, r, config.Value().MatchDomains)
case ListHomepageConfig: case ListHomepageConfig:
listHomepageConfig(cfg, w, r) U.RespondJSON(w, r, config.HomepageConfig())
case ListTasks:
U.RespondJSON(w, r, task.DebugTaskMap())
default: default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
} }
} }
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type")))
U.RespondJSON(w, r, routes)
}
func listConfigFiles(w http.ResponseWriter, r *http.Request) { func listConfigFiles(w http.ResponseWriter, r *http.Request) {
files, err := utils.ListFiles(common.ConfigBasePath, 1) files, err := utils.ListFiles(common.ConfigBasePath, 1)
if err != nil { if err != nil {
@ -61,19 +60,3 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
} }
U.RespondJSON(w, r, files) U.RespondJSON(w, r, files)
} }
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, middleware.GetAllTrace())
}
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, middleware.All())
}
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, cfg.Value().MatchDomains)
}
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, cfg.HomepageConfig())
}

View file

@ -34,36 +34,34 @@ func ReloadServer() E.NestedError {
return nil return nil
} }
func ListRoutes() (map[string]map[string]any, E.NestedError) { func List[T any](what string) (_ T, outErr E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes)) resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
if err != nil { if err != nil {
return nil, E.From(err) outErr = E.From(err)
return
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode) outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
return
} }
var routes map[string]map[string]any var res T
err = json.NewDecoder(resp.Body).Decode(&routes) err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil { if err != nil {
return nil, E.From(err) outErr = E.From(err)
return
} }
return routes, nil return res, nil
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
return List[map[string]map[string]any](v1.ListRoutes)
} }
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) { func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace)) return List[middleware.Traces](v1.ListMiddlewareTraces)
if err != nil { }
return nil, E.From(err)
} func DebugListTasks() (map[string]any, E.NestedError) {
defer resp.Body.Close() return List[map[string]any](v1.ListTasks)
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
}
var traces middleware.Traces
err = json.NewDecoder(resp.Body).Decode(&traces)
if err != nil {
return nil, E.From(err)
}
return traces, nil
} }

View file

@ -7,8 +7,8 @@ import (
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
) )
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func Reload(w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err != nil { if err := config.Reload(); err != nil {
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
} else { } else {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)

View file

@ -14,19 +14,19 @@ import (
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func Stats(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats(cfg)) U.RespondJSON(w, r, getStats())
} }
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func StatsWS(w http.ResponseWriter, r *http.Request) {
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses)) originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
if len(originPats) == 0 { if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins") U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
originPats = []string{"*"} originPats = []string{"*"}
} else { } else {
for i, domain := range cfg.Value().MatchDomains { for i, domain := range config.Value().MatchDomains {
originPats[i] = "*." + domain originPats[i] = "*." + domain
} }
originPats = append(originPats, localAddresses...) originPats = append(originPats, localAddresses...)
@ -51,7 +51,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
stats := getStats(cfg) stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil { if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err) U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
return return
@ -59,9 +59,9 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
} }
} }
func getStats(cfg *config.Config) map[string]any { func getStats() map[string]any {
return map[string]any{ return map[string]any{
"proxies": cfg.Statistics(), "proxies": config.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()), "uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
} }
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types" "github.com/yusing/go-proxy/internal/config/types"
) )
type Config types.AutoCertConfig type Config types.AutoCertConfig

View file

@ -13,9 +13,9 @@ import (
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
) )
@ -140,23 +140,22 @@ func (p *Provider) ScheduleRenewal() {
if p.GetName() == ProviderLocal { if p.GetName() == ProviderLocal {
return return
} }
go func() {
ticker := time.NewTicker(5 * time.Second) task := task.GlobalTask("cert renew scheduler")
defer ticker.Stop() ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
task := common.NewTask("cert renew scheduler") defer task.Finish("cert renew scheduler stopped")
defer task.Finished() for {
select {
for { case <-task.Context().Done():
select { return
case <-task.Context().Done(): case <-ticker.C: // check every 5 seconds
return if err := p.renewIfNeeded(); err.HasError() {
case <-ticker.C: // check every 5 seconds logger.Warn(err)
if err := p.renewIfNeeded(); err.HasError() { }
logger.Warn(err)
} }
} }
} }()
} }
func (p *Provider) initClient() E.NestedError { func (p *Provider) initClient() E.NestedError {

View file

@ -17,7 +17,7 @@ func (p *Provider) Setup() (err E.NestedError) {
} }
} }
go p.ScheduleRenewal() p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() { for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry) logger.Infof("certificate expire on %s", expiry)

View file

@ -22,6 +22,7 @@ const (
CommandDebugListEntries = "debug-ls-entries" CommandDebugListEntries = "debug-ls-entries"
CommandDebugListProviders = "debug-ls-providers" CommandDebugListProviders = "debug-ls-providers"
CommandDebugListMTrace = "debug-ls-mtrace" CommandDebugListMTrace = "debug-ls-mtrace"
CommandDebugListTasks = "debug-ls-tasks"
) )
var ValidCommands = []string{ var ValidCommands = []string{
@ -35,6 +36,7 @@ var ValidCommands = []string{
CommandDebugListEntries, CommandDebugListEntries,
CommandDebugListProviders, CommandDebugListProviders,
CommandDebugListMTrace, CommandDebugListMTrace,
CommandDebugListTasks,
} }
func GetArgs() Args { func GetArgs() Args {

View file

@ -43,7 +43,6 @@ const (
HealthCheckIntervalDefault = 5 * time.Second HealthCheckIntervalDefault = 5 * time.Second
HealthCheckTimeoutDefault = 5 * time.Second HealthCheckTimeoutDefault = 5 * time.Second
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "30s" WakeTimeoutDefault = "30s"
StopTimeoutDefault = "10s" StopTimeoutDefault = "10s"
StopMethodDefault = "stop" StopMethodDefault = "stop"

View file

@ -15,6 +15,7 @@ var (
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true) NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true)
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test") IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest) IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug
ProxyHTTPAddr, ProxyHTTPAddr,
ProxyHTTPHost, ProxyHTTPHost,

View file

@ -1,224 +0,0 @@
package common
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/sirupsen/logrus"
)
var (
globalCtx, globalCtxCancel = context.WithCancel(context.Background())
taskWg sync.WaitGroup
tasksMap = xsync.NewMapOf[*task, struct{}]()
)
type (
Task interface {
Name() string
Context() context.Context
Subtask(usageFmt string, args ...interface{}) Task
SubtaskWithCancel(usageFmt string, args ...interface{}) (Task, context.CancelFunc)
Finished()
}
task struct {
ctx context.Context
subtasks []*task
name string
finished bool
mu sync.Mutex
}
)
func (t *task) Name() string {
return t.name
}
// Context returns the context associated with the task. This context is
// canceled when the task is finished.
func (t *task) Context() context.Context {
return t.ctx
}
// Finished marks the task as finished and notifies the global wait group.
// Finished is thread-safe and idempotent.
func (t *task) Finished() {
t.mu.Lock()
defer t.mu.Unlock()
if t.finished {
return
}
t.finished = true
if _, ok := tasksMap.Load(t); ok {
taskWg.Done()
tasksMap.Delete(t)
}
logrus.Debugf("task %q finished", t.Name())
}
// Subtask returns a new subtask with the given name, derived from the receiver's context.
//
// The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be
// canceled immediately.
//
// The returned subtask is safe for concurrent use.
func (t *task) Subtask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
t.mu.Lock()
defer t.mu.Unlock()
sub := newSubTask(t.ctx, format)
t.subtasks = append(t.subtasks, sub)
return sub
}
// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context,
// and a cancel function. The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be canceled immediately.
//
// The returned cancel function is safe for concurrent use, and can be used to cancel the returned
// subtask at any time.
func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
t.mu.Lock()
defer t.mu.Unlock()
ctx, cancel := context.WithCancel(t.ctx)
sub := newSubTask(ctx, format)
t.subtasks = append(t.subtasks, sub)
return sub, cancel
}
func (t *task) tree(prefix ...string) string {
var sb strings.Builder
var pre string
if len(prefix) > 0 {
pre = prefix[0]
}
sb.WriteString(pre)
sb.WriteString(t.Name() + "\n")
for _, sub := range t.subtasks {
if sub.finished {
continue
}
sb.WriteString(sub.tree(pre + " "))
}
return sb.String()
}
func newSubTask(ctx context.Context, name string) *task {
t := &task{
ctx: ctx,
name: name,
}
tasksMap.Store(t, struct{}{})
taskWg.Add(1)
logrus.Debugf("task %q started", name)
return t
}
// NewTask returns a new Task with the given name, derived from the global
// context.
//
// The returned Task is associated with the global context and will be
// automatically registered and deregistered from the global context's wait
// group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is not safe for concurrent use.
func NewTask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return newSubTask(globalCtx, format)
}
// NewTaskWithCancel returns a new Task with the given name, derived from the
// global context, and a cancel function. The returned Task is associated with
// the global context and will be automatically registered and deregistered
// from the global task wait group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is safe for concurrent use.
//
// The returned cancel function is safe for concurrent use, and can be used
// to cancel the returned Task at any time.
func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
subCtx, cancel := context.WithCancel(globalCtx)
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return newSubTask(subCtx, format), cancel
}
// GlobalTask returns a new Task with the given name, associated with the
// global context.
//
// Unlike NewTask, GlobalTask does not automatically register or deregister
// the Task with the global task wait group. The returned Task is not
// started, but the name is formatted immediately.
//
// This is best used for main task that do not need to wait and
// will create a bunch of subtasks.
//
// The returned Task is safe for concurrent use.
func GlobalTask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return &task{
ctx: globalCtx,
name: format,
}
}
// CancelGlobalContext cancels the global context, which will cause all tasks
// created by GlobalTask or NewTask to be canceled. This should be called
// before exiting the program to ensure that all tasks are properly cleaned
// up.
func CancelGlobalContext() {
globalCtxCancel()
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) {
done := make(chan struct{})
after := time.After(timeout)
go func() {
taskWg.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
logrus.Warnln("Timeout waiting for these tasks to finish:")
tasksMap.Range(func(t *task, _ struct{}) bool {
logrus.Warnln(t.tree())
return true
})
return
}
}
}

View file

@ -2,51 +2,66 @@ package config
import ( import (
"os" "os"
"sync"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
PR "github.com/yusing/go-proxy/internal/proxy/provider" "github.com/yusing/go-proxy/internal/route"
R "github.com/yusing/go-proxy/internal/route" proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/types" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Config struct { type Config struct {
value *types.Config value *types.Config
proxyProviders F.Map[string, *PR.Provider] providers F.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider autocertProvider *autocert.Provider
task task.Task
l logrus.FieldLogger
watcher W.Watcher
reloadReq chan struct{}
} }
var instance *Config var (
instance *Config
cfgWatcher watcher.Watcher
logger = logrus.WithField("module", "config")
reloadMu sync.Mutex
)
const configEventFlushInterval = 500 * time.Millisecond
const (
cfgRenameWarn = `Config file renamed, not reloading.
Make sure you rename it back before next time you start.`
cfgDeleteWarn = `Config file deleted, not reloading.
You may run "ls-config" to show or dump the current config.`
)
func GetInstance() *Config { func GetInstance() *Config {
return instance return instance
} }
func Load() E.NestedError { func newConfig() *Config {
return &Config{
value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](),
task: task.GlobalTask("config"),
}
}
func Load() (*Config, E.NestedError) {
if instance != nil { if instance != nil {
return nil return instance, nil
} }
instance = &Config{ instance = newConfig()
value: types.DefaultConfig(), cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
proxyProviders: F.NewMapOf[string, *PR.Provider](), return instance, instance.load()
l: logrus.WithField("module", "config"),
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
return instance.load()
} }
func Validate(data []byte) E.NestedError { func Validate(data []byte) E.NestedError {
@ -54,87 +69,90 @@ func Validate(data []byte) E.NestedError {
} }
func MatchDomains() []string { func MatchDomains() []string {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return instance.value.MatchDomains return instance.value.MatchDomains
} }
func (cfg *Config) Value() types.Config { func WatchChanges() {
if cfg == nil { task := task.GlobalTask("Config watcher")
logrus.Panic("config has not been loaded, please check if there is any errors") eventQueue := events.NewEventQueue(
} task,
return *cfg.value configEventFlushInterval,
OnConfigChange,
func(err E.NestedError) {
logger.Error(err)
},
)
eventQueue.Start(cfgWatcher.Events(task.Context()))
} }
func (cfg *Config) GetAutoCertProvider() *autocert.Provider { func OnConfigChange(flushTask task.Task, ev []events.Event) {
if instance == nil { defer flushTask.Finish("config reload complete")
logrus.Panic("config has not been loaded, please check if there is any errors")
// no matter how many events during the interval
// just reload once and check the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
logger.Warn(cfgRenameWarn)
return
case events.ActionFileDeleted:
logger.Warn(cfgDeleteWarn)
return
}
if err := Reload(); err != nil {
logger.Error(err)
} }
return cfg.autocertProvider
} }
func (cfg *Config) Reload() (err E.NestedError) { func Reload() E.NestedError {
cfg.stopProviders() // avoid race between config change and API reload request
err = cfg.load() reloadMu.Lock()
cfg.StartProxyProviders() defer reloadMu.Unlock()
return
newCfg := newConfig()
err := newCfg.load()
if err != nil {
return err
}
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
instance.task.Finish("config changed")
instance.task.Wait()
*instance = *newCfg
instance.StartProxyProviders()
return nil
}
func Value() types.Config {
return *instance.value
}
func GetAutoCertProvider() *autocert.Provider {
return instance.autocertProvider
}
func (cfg *Config) Task() task.Task {
return cfg.task
} }
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes) b := E.NewBuilder("errors starting providers")
} cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName())))
func (cfg *Config) WatchChanges() {
task := common.NewTask("Config watcher")
go func() {
defer task.Finished()
for {
select {
case <-task.Context().Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err != nil {
cfg.l.Error(err)
}
}
}
}()
go func() {
eventCh, errCh := cfg.watcher.Events(task.Context())
for {
select {
case <-task.Context().Done():
return
case event := <-eventCh:
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
cfg.l.Error("config file deleted or renamed, ignoring...")
continue
} else {
cfg.reloadReq <- struct{}{}
}
case err := <-errCh:
cfg.l.Error(err)
continue
}
}
}()
}
func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r *R.Route) {
do(a, r, p)
})
}) })
if b.HasError() {
logger.Error(b.Build())
}
} }
func (cfg *Config) load() (res E.NestedError) { func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config") b := E.NewBuilder("errors loading config")
defer b.To(&res) defer b.To(&res)
cfg.l.Debug("loading config") logger.Debug("loading config")
defer cfg.l.Debug("loaded config") defer logger.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath)) data, err := E.Check(os.ReadFile(common.ConfigPath))
if err != nil { if err != nil {
@ -160,7 +178,7 @@ func (cfg *Config) load() (res E.NestedError) {
b.Add(cfg.loadProviders(&model.Providers)) b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model cfg.value = model
R.SetFindMuxDomains(model.MatchDomains) route.SetFindMuxDomains(model.MatchDomains)
return return
} }
@ -169,8 +187,8 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return return
} }
cfg.l.Debug("initializing autocert") logger.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert") defer logger.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err != nil { if err != nil {
@ -179,48 +197,34 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return return
} }
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) { func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) {
cfg.l.Debug("loading providers") subtask := cfg.task.Subtask("load providers")
defer cfg.l.Debug("loaded providers") defer subtask.Finish("done")
b := E.NewBuilder("errors loading providers") errs := E.NewBuilder("errors loading providers")
defer b.To(&res) results := E.NewBuilder("loaded providers")
defer errs.To(&outErr)
for _, filename := range providers.Files { for _, filename := range providers.Files {
p, err := PR.NewFileProvider(filename) p, err := proxy.NewFileProvider(filename)
if err != nil { if err != nil {
b.Add(err.Subject(filename)) errs.Add(err)
continue continue
} }
cfg.proxyProviders.Store(p.GetName(), p) cfg.providers.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(filename)) errs.Add(p.LoadRoutes().Subject(filename))
results.Addf("%d routes from %s", p.NumRoutes(), filename)
} }
for name, dockerHost := range providers.Docker { for name, dockerHost := range providers.Docker {
p, err := PR.NewDockerProvider(name, dockerHost) p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil { if err != nil {
b.Add(err.Subjectf("%s (%s)", name, dockerHost)) errs.Add(err.Subjectf("%s (%s)", name, dockerHost))
continue continue
} }
cfg.proxyProviders.Store(p.GetName(), p) cfg.providers.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(p.GetName())) errs.Add(p.LoadRoutes().Subject(p.GetName()))
results.Addf("%d routes from %s", p.NumRoutes(), name)
} }
logger.Info(results.Build())
return return
} }
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
if err := do(p); err != nil {
errors.Add(err.Subject(p))
}
})
if err := errors.Build(); err != nil {
cfg.l.Error(err)
}
}
func (cfg *Config) stopProviders() {
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
}

View file

@ -6,33 +6,35 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/homepage"
PR "github.com/yusing/go-proxy/internal/proxy/provider" "github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types" proxy "github.com/yusing/go-proxy/internal/route/provider"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
func (cfg *Config) DumpEntries() map[string]*types.RawEntry { func DumpEntries() map[string]*entry.RawEntry {
entries := make(map[string]*types.RawEntry) entries := make(map[string]*entry.RawEntry)
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { instance.providers.RangeAll(func(_ string, p *proxy.Provider) {
entries[alias] = r.Entry p.RangeRoutes(func(alias string, r *route.Route) {
entries[alias] = r.Entry
})
}) })
return entries return entries
} }
func (cfg *Config) DumpProviders() map[string]*PR.Provider { func DumpProviders() map[string]*proxy.Provider {
entries := make(map[string]*PR.Provider) entries := make(map[string]*proxy.Provider)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { instance.providers.RangeAll(func(name string, p *proxy.Provider) {
entries[name] = p entries[name] = p
}) })
return entries return entries
} }
func (cfg *Config) HomepageConfig() homepage.Config { func HomepageConfig() homepage.Config {
var proto, port string var proto, port string
domains := cfg.value.MatchDomains domains := instance.value.MatchDomains
cert, _ := cfg.autocertProvider.GetCert(nil) cert, _ := instance.autocertProvider.GetCert(nil)
if cert != nil { if cert != nil {
proto = "https" proto = "https"
port = common.ProxyHTTPSPort port = common.ProxyHTTPSPort
@ -42,9 +44,9 @@ func (cfg *Config) HomepageConfig() homepage.Config {
} }
hpCfg := homepage.NewHomePageConfig() hpCfg := homepage.NewHomePageConfig()
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
entry := r.Raw en := r.Raw
item := entry.Homepage item := en.Homepage
if item == nil { if item == nil {
item = new(homepage.Item) item = new(homepage.Item)
item.Show = true item.Show = true
@ -63,12 +65,12 @@ func (cfg *Config) HomepageConfig() homepage.Config {
) )
} }
if r.IsDocker() { if entry.IsDocker(r) {
if item.Category == "" { if item.Category == "" {
item.Category = "Docker" item.Category = "Docker"
} }
item.SourceType = string(PR.ProviderTypeDocker) item.SourceType = string(proxy.ProviderTypeDocker)
} else if r.UseLoadBalance() { } else if entry.UseLoadBalance(r) {
if item.Category == "" { if item.Category == "" {
item.Category = "Load-balanced" item.Category = "Load-balanced"
} }
@ -77,7 +79,7 @@ func (cfg *Config) HomepageConfig() homepage.Config {
if item.Category == "" { if item.Category == "" {
item.Category = "Others" item.Category = "Others"
} }
item.SourceType = string(PR.ProviderTypeFile) item.SourceType = string(proxy.ProviderTypeFile)
} }
if item.URL == "" { if item.URL == "" {
@ -85,26 +87,26 @@ func (cfg *Config) HomepageConfig() homepage.Config {
item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port) item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port)
} }
} }
item.AltURL = r.URL().String() item.AltURL = r.TargetURL().String()
hpCfg.Add(item) hpCfg.Add(item)
}) })
return hpCfg return hpCfg
} }
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any { func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
routes := make(map[string]any) routes := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" { if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream} typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
} }
for _, t := range typeFilter { for _, t := range typeFilter {
switch t { switch t {
case R.RouteTypeReverseProxy: case route.RouteTypeReverseProxy:
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
routes[alias] = r routes[alias] = r
}) })
case R.RouteTypeStream: case route.RouteTypeStream:
R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) { route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) {
routes[alias] = r routes[alias] = r
}) })
} }
@ -112,12 +114,12 @@ func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
return routes return routes
} }
func (cfg *Config) Statistics() map[string]any { func Statistics() map[string]any {
nTotalStreams := 0 nTotalStreams := 0
nTotalRPs := 0 nTotalRPs := 0
providerStats := make(map[string]PR.ProviderStats) providerStats := make(map[string]proxy.ProviderStats)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { instance.providers.RangeAll(func(name string, p *proxy.Provider) {
providerStats[name] = p.Statistics() providerStats[name] = p.Statistics()
}) })
@ -133,9 +135,9 @@ func (cfg *Config) Statistics() map[string]any {
} }
} }
func (cfg *Config) FindRoute(alias string) *R.Route { func FindRoute(alias string) *route.Route {
return F.MapFind(cfg.proxyProviders, return F.MapFind(instance.providers,
func(p *PR.Provider) (*R.Route, bool) { func(p *proxy.Provider) (*route.Route, bool) {
if route, ok := p.GetRoute(alias); ok { if route, ok := p.GetRoute(alias); ok {
return route, true return route, true
} }

View file

@ -0,0 +1,24 @@
package types
type (
Config struct {
Providers ProxyProviders `json:"providers" yaml:",flow"`
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
}
ProxyProviders struct {
Files []string `json:"include" yaml:"include"` // docker, file
Docker map[string]string `json:"docker" yaml:"docker"`
}
)
func DefaultConfig() *Config {
return &Config{
Providers: ProxyProviders{},
TimeoutShutdown: 3,
RedirectToHTTPS: false,
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
@ -36,22 +37,13 @@ var (
) )
func init() { func init() {
go func() { task.GlobalTask("close docker clients").OnComplete("", func() {
task := common.NewTask("close all docker client") clientMap.RangeAllParallel(func(_ string, c Client) {
defer task.Finished() if c.Connected() {
for { c.Client.Close()
select {
case <-task.Context().Done():
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
}
})
clientMap.Clear()
return
} }
} })
}() })
} }
func (c *SharedClient) Connected() bool { func (c *SharedClient) Connected() bool {
@ -141,19 +133,10 @@ func ConnectClient(host string) (Client, E.NestedError) {
<-c.refCount.Zero() <-c.refCount.Zero()
clientMap.Delete(c.key) clientMap.Delete(c.key)
if c.Client != nil { if c.Connected() {
c.Client.Close() c.Client.Close()
c.Client = nil
c.l.Debugf("client closed") c.l.Debugf("client closed")
} }
}() }()
return c, nil return c, nil
} }
func CloseAllClients() {
clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()
logger.Debug("closed all clients")
}

View file

@ -2,6 +2,7 @@ package docker
import ( import (
"context" "context"
"errors"
"time" "time"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
@ -16,10 +17,13 @@ type ClientInfo struct {
} }
var listOptions = container.ListOptions{ var listOptions = container.ListOptions{
// created|restarting|running|removing|paused|exited|dead
// Filters: filters.NewArgs( // Filters: filters.NewArgs(
// filters.Arg("health", "healthy"), // filters.Arg("status", "created"),
// filters.Arg("health", "none"), // filters.Arg("status", "restarting"),
// filters.Arg("health", "starting"), // filters.Arg("status", "running"),
// filters.Arg("status", "paused"),
// filters.Arg("status", "exited"),
// ), // ),
All: true, All: true,
} }
@ -31,7 +35,7 @@ func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedE
} }
defer dockerClient.Close() defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout"))
defer cancel() defer cancel()
var containers []types.Container var containers []types.Container

View file

@ -32,11 +32,11 @@ type (
IsExcluded bool `json:"is_excluded" yaml:"-"` IsExcluded bool `json:"is_excluded" yaml:"-"`
IsExplicit bool `json:"is_explicit" yaml:"-"` IsExplicit bool `json:"is_explicit" yaml:"-"`
IsDatabase bool `json:"is_database" yaml:"-"` IsDatabase bool `json:"is_database" yaml:"-"`
IdleTimeout string `json:"idle_timeout" yaml:"-"` IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"`
WakeTimeout string `json:"wake_timeout" yaml:"-"` WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"`
StopMethod string `json:"stop_method" yaml:"-"` StopMethod string `json:"stop_method,omitempty" yaml:"-"`
StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only
StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only
Running bool `json:"running" yaml:"-"` Running bool `json:"running" yaml:"-"`
} }
) )

View file

@ -0,0 +1,112 @@
package idlewatcher
import (
"time"
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
type (
Config struct {
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
StopMethod StopMethod `json:"stop_method,omitempty"`
StopSignal Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StopMethod string
Signal string
)
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
if cont == nil {
return nil, nil
}
if cont.IdleTimeout == "" {
return &Config{
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
b := E.NewBuilder("invalid idlewatcher config")
defer b.To(&res)
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout)
b.Add(err.Subjectf("%s", "idle_timeout"))
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout)
b.Add(err.Subjectf("%s", "wake_timeout"))
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
b.Add(err.Subjectf("%s", "stop_timeout"))
stopMethod, err := validateStopMethod(cont.StopMethod)
b.Add(err)
signal, err := validateSignal(cont.StopSignal)
b.Add(err)
if err := b.Build(); err != nil {
return
}
return &Config{
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopTimeout: int(stopTimeout.Seconds()),
StopMethod: stopMethod,
StopSignal: signal,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value).With(err)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}
func validateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}
func validateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View file

@ -20,16 +20,15 @@ var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(lo
const headerCheckRedirect = "X-Goproxy-Check-Redirect" const headerCheckRedirect = "X-Goproxy-Check-Redirect"
func (w *Watcher) makeRespBody(format string, args ...any) []byte { func (w *Watcher) makeLoadingPageBody() []byte {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf("%s is starting...", w.ContainerName)
data := new(templateData) data := new(templateData)
data.CheckRedirectHeader = headerCheckRedirect data.CheckRedirectHeader = headerCheckRedirect
data.Title = w.ContainerName data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, "\n", "<br>") data.Message = strings.ReplaceAll(msg, " ", "&ensp;")
data.Message = strings.ReplaceAll(data.Message, " ", "&ensp;")
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect)))
err := loadingPageTmpl.Execute(buf, data) err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen in production if err != nil { // should never happen in production
panic(err) panic(err)

View file

@ -1,197 +1,133 @@
package idlewatcher package idlewatcher
import ( import (
"context"
"net/http" "net/http"
"strconv" "sync/atomic"
"time" "time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
type Waker struct { type Waker interface {
*Watcher health.HealthMonitor
http.Handler
net.Stream
}
type waker struct {
_ U.NoCopy
client *http.Client
rp *gphttp.ReverseProxy rp *gphttp.ReverseProxy
stream net.Stream
hc health.HealthChecker
ready atomic.Bool
} }
func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker { const (
return &Waker{ idleWakerCheckInterval = 100 * time.Millisecond
Watcher: w, idleWakerCheckTimeout = time.Second
client: &http.Client{ )
Timeout: 1 * time.Second,
Transport: rp.Transport, // TODO: support stream
},
rp: rp, func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) {
hcCfg := entry.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout
waker := &waker{
rp: rp,
stream: stream,
} }
}
func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) { watcher, err := registerWatcher(providerSubTask, entry, waker)
shouldNext := w.wake(rw, r) if err != nil {
if !shouldNext { return nil, err
return
} }
w.rp.ServeHTTP(rw, r)
if rp != nil {
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
} else if stream != nil {
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
} else {
panic("both nil")
}
return watcher, nil
} }
/* HealthMonitor interface */ // lifetime should follow route provider
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) {
func (w *Waker) Start() {} return newWaker(providerSubTask, entry, rp, nil)
func (w *Waker) Stop() {
w.Unregister()
} }
func (w *Waker) UpdateConfig(config health.HealthCheckConfig) { func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) {
panic("use idlewatcher.Register instead") return newWaker(providerSubTask, entry, nil, stream)
} }
func (w *Waker) Name() string { // Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError {
w.task.OnComplete("stop route", func() {
routeSubTask.Parent().Finish("watcher stopped")
})
return nil
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason string) {}
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
return w.String() return w.String()
} }
func (w *Waker) String() string { // String implements health.HealthMonitor.
return string(w.Alias) func (w *Watcher) String() string {
return w.ContainerName
} }
func (w *Waker) Status() health.Status { // Uptime implements health.HealthMonitor.
if w.ready.Load() { func (w *Watcher) Uptime() time.Duration {
return health.StatusHealthy
}
if !w.ContainerRunning {
return health.StatusNapping
}
return health.StatusStarting
}
func (w *Waker) Uptime() time.Duration {
return 0 return 0
} }
func (w *Waker) MarshalJSON() ([]byte, error) { // Status implements health.HealthMonitor.
var url types.URL func (w *Watcher) Status() health.Status {
if w.URL.String() != "http://:0" { if !w.ContainerRunning {
url = w.URL return health.StatusNapping
}
if w.ready.Load() {
return health.StatusHealthy
}
healthy, _, err := w.hc.CheckHealth()
switch {
case err != nil:
return health.StatusError
case healthy:
w.ready.Store(true)
return health.StatusHealthy
default:
return health.StatusStarting
}
}
// MarshalJSON implements health.HealthMonitor.
func (w *Watcher) MarshalJSON() ([]byte, error) {
var url net.URL
if w.hc.URL().Port() != "0" {
url = w.hc.URL()
} }
return (&health.JSONRepresentation{ return (&health.JSONRepresentation{
Name: w.Name(), Name: w.Name(),
Status: w.Status(), Status: w.Status(),
Config: &health.HealthCheckConfig{ Config: w.hc.Config(),
Interval: w.IdleTimeout, URL: url,
Timeout: w.WakeTimeout,
},
URL: url,
}).MarshalJSON() }).MarshalJSON()
} }
/* End of HealthMonitor interface */
func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is ready
if w.ready.Load() {
return true
}
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
defer cancel()
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeRespBody("%s waking up...", w.ContainerName)
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
}
return
}
select {
case <-w.task.Context().Done():
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
w.l.Debug("wake signal received")
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return
}
// maybe another request came in while we were waiting for the wake
if w.ready.Load() {
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
return
}
return true
}
for {
select {
case <-w.task.Context().Done():
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
wakeReq, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
w.URL.String(),
nil,
)
if err != nil {
w.l.Errorf("new request err to %s: %s", r.URL, err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
wakeResp, err := w.client.Do(wakeReq)
if err == nil && wakeResp.StatusCode != http.StatusServiceUnavailable {
w.ready.Store(true)
w.l.Debug("awaken")
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
return
}
logrus.Infof("container %s is ready, passing through to %s", w.Alias, w.rp.TargetURL)
return true
}
// retry until the container is ready or timeout
time.Sleep(100 * time.Millisecond)
}
}
// static HealthMonitor interface check
func (w *Waker) _() health.HealthMonitor {
return w
}

View file

@ -0,0 +1,105 @@
package idlewatcher
import (
"context"
"errors"
"net/http"
"strconv"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// ServeHTTP implements http.Handler
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
shouldNext := w.wakeFromHTTP(rw, r)
if !shouldNext {
return
}
w.rp.ServeHTTP(rw, r)
}
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is already ready
if w.ready.Load() {
return true
}
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeLoadingPageBody()
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
}
return
}
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
checkCancelled := func() bool {
select {
case <-w.task.Context().Done():
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
case <-ctx.Done():
w.l.Debugf("wake cancelled: %s", context.Cause(ctx))
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return true
default:
return false
}
}
if checkCancelled() {
return false
}
w.l.Debug("wake signal received")
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return
}
for {
if checkCancelled() {
return false
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting...", w.String())
rw.WriteHeader(http.StatusOK)
return
}
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
return true
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View file

@ -0,0 +1,87 @@
package idlewatcher
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Setup implements types.Stream.
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
conn, err = w.stream.Accept()
// timeout means no connection is accepted
var nErr *net.OpError
ok := errors.As(err, &nErr)
if ok && nErr.Timeout() {
return
}
if err := w.wakeFromStream(); err != nil {
return nil, err
}
return w.stream.Accept()
}
// CloseListeners implements types.Stream.
func (w *Watcher) CloseListeners() {
w.stream.CloseListeners()
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
return w.stream.Handle(conn)
}
func (w *Watcher) wakeFromStream() error {
// pass through if container is already ready
if w.ready.Load() {
return nil
}
w.l.Debug("wake signal received")
wakeErr := w.wakeIfStopped()
if wakeErr != nil {
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr)
w.l.Error(wakeErr)
return wakeErr
}
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
for {
select {
case <-w.task.Context().Done():
cause := w.task.FinishCause()
w.l.Debugf("wake cancelled: %s", cause)
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.l.Debugf("wake cancelled: %s", cause)
return cause
default:
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
return nil
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View file

@ -2,191 +2,193 @@ package idlewatcher
import ( import (
"context" "context"
"errors"
"fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy" "github.com/yusing/go-proxy/internal/proxy/entry"
PT "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
) )
type ( type (
Watcher struct { Watcher struct {
*P.ReverseProxyEntry _ U.NoCopy
client D.Client *idlewatcher.Config
*waker
ready atomic.Bool // whether the site is ready to accept connection client D.Client
stopByMethod StopCallback // send a docker command w.r.t. `stop_method` stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
ticker *time.Ticker
ticker *time.Ticker task task.Task
l *logrus.Entry
task common.Task
cancel context.CancelFunc
refCount *U.RefCount
l logrus.FieldLogger
} }
WakeDone <-chan error WakeDone <-chan error
WakeFunc func() WakeDone WakeFunc func() WakeDone
StopCallback func() E.NestedError StopCallback func() error
) )
var ( var (
watcherMap = F.NewMapOf[string, *Watcher]() watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex watcherMapMu sync.Mutex
portHistoryMap = F.NewMapOf[PT.Alias, string]()
logger = logrus.WithField("module", "idle_watcher") logger = logrus.WithField("module", "idle_watcher")
) )
func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) { const dockerReqTimeout = 3 * time.Second
failure := E.Failure("idle_watcher register")
if entry.IdleTimeout == 0 { func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) {
return nil, failure.With(E.Invalid("idle_timeout", 0)) failure := E.Failure("idle_watcher register")
cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 {
panic("should not reach here")
} }
watcherMapMu.Lock() watcherMapMu.Lock()
defer watcherMapMu.Unlock() defer watcherMapMu.Unlock()
key := entry.ContainerID key := cfg.ContainerID
if entry.URL.Port() != "0" {
portHistoryMap.Store(entry.Alias, entry.URL.Port())
}
if w, ok := watcherMap.Load(key); ok { if w, ok := watcherMap.Load(key); ok {
w.refCount.Add() w.Config = cfg
w.ReverseProxyEntry = entry w.waker = waker
w.resetIdleTimer()
return w, nil return w, nil
} }
client, err := D.ConnectClient(entry.DockerHost) client, err := D.ConnectClient(cfg.DockerHost)
if err.HasError() { if err.HasError() {
return nil, failure.With(err) return nil, failure.With(err)
} }
w := &Watcher{ w := &Watcher{
ReverseProxyEntry: entry, Config: cfg,
client: client, waker: waker,
refCount: U.NewRefCounter(), client: client,
ticker: time.NewTicker(entry.IdleTimeout), task: providerSubtask,
l: logger.WithField("container", entry.ContainerName), ticker: time.NewTicker(cfg.IdleTimeout),
l: logger.WithField("container", cfg.ContainerName),
} }
w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias)
w.stopByMethod = w.getStopCallback() w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w) watcherMap.Store(key, w)
go w.watchUntilCancel() go func() {
cause := w.watchUntilDestroy()
watcherMapMu.Lock()
watcherMap.Delete(w.ContainerID)
watcherMapMu.Unlock()
w.ticker.Stop()
w.client.Close()
w.task.Finish(cause.Error())
}()
return w, nil return w, nil
} }
func (w *Watcher) Unregister() { func (w *Watcher) containerStop(ctx context.Context) error {
w.refCount.Sub() return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
}
func (w *Watcher) containerStop() error {
return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal), Signal: string(w.StopSignal),
Timeout: &w.StopTimeout, Timeout: &w.StopTimeout,
}) })
} }
func (w *Watcher) containerPause() error { func (w *Watcher) containerPause(ctx context.Context) error {
return w.client.ContainerPause(w.task.Context(), w.ContainerID) return w.client.ContainerPause(ctx, w.ContainerID)
} }
func (w *Watcher) containerKill() error { func (w *Watcher) containerKill(ctx context.Context) error {
return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal)) return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal))
} }
func (w *Watcher) containerUnpause() error { func (w *Watcher) containerUnpause(ctx context.Context) error {
return w.client.ContainerUnpause(w.task.Context(), w.ContainerID) return w.client.ContainerUnpause(ctx, w.ContainerID)
} }
func (w *Watcher) containerStart() error { func (w *Watcher) containerStart(ctx context.Context) error {
return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{}) return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{})
} }
func (w *Watcher) containerStatus() (string, E.NestedError) { func (w *Watcher) containerStatus() (string, error) {
if !w.client.Connected() { if !w.client.Connected() {
return "", E.Failure("docker client closed") return "", errors.New("docker client not connected")
} }
json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID) ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
if err != nil { if err != nil {
return "", E.FailWith("inspect container", err) return "", fmt.Errorf("failed to inspect container: %w", err)
} }
return json.State.Status, nil return json.State.Status, nil
} }
func (w *Watcher) wakeIfStopped() E.NestedError { func (w *Watcher) wakeIfStopped() error {
if w.ready.Load() || w.ContainerRunning { if w.ContainerRunning {
return nil return nil
} }
status, err := w.containerStatus() status, err := w.containerStatus()
if err != nil {
if err.HasError() {
return err return err
} }
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
defer cancel()
// !Hard coded here since theres no constants from Docker API
switch status { switch status {
case "exited", "dead": case "exited", "dead":
return E.From(w.containerStart()) return w.containerStart(ctx)
case "paused": case "paused":
return E.From(w.containerUnpause()) return w.containerUnpause(ctx)
case "running": case "running":
return nil return nil
default: default:
return E.Unexpected("container state", status) panic("should not reach here")
} }
} }
func (w *Watcher) getStopCallback() StopCallback { func (w *Watcher) getStopCallback() StopCallback {
var cb func() error var cb func(context.Context) error
switch w.StopMethod { switch w.StopMethod {
case PT.StopMethodPause: case idlewatcher.StopMethodPause:
cb = w.containerPause cb = w.containerPause
case PT.StopMethodStop: case idlewatcher.StopMethodStop:
cb = w.containerStop cb = w.containerStop
case PT.StopMethodKill: case idlewatcher.StopMethodKill:
cb = w.containerKill cb = w.containerKill
default: default:
panic("should not reach here") panic("should not reach here")
} }
return func() E.NestedError { return func() error {
status, err := w.containerStatus() ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
if err.HasError() { defer cancel()
return err return cb(ctx)
}
if status != "running" {
return nil
}
return E.From(cb())
} }
} }
func (w *Watcher) resetIdleTimer() { func (w *Watcher) resetIdleTimer() {
w.l.Trace("reset idle timer")
w.ticker.Reset(w.IdleTimeout) w.ticker.Reset(w.IdleTimeout)
} }
func (w *Watcher) watchUntilCancel() { func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) {
dockerWatcher := W.NewDockerWatcherWithClient(w.client) eventTask = w.task.Subtask("watcher for %s", w.ContainerID)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{ eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter( Filters: W.NewDockerFilter(
W.DockerFilterContainer, W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID), W.DockerrFilterContainer(w.ContainerID),
@ -194,34 +196,47 @@ func (w *Watcher) watchUntilCancel() {
W.DockerFilterStop, W.DockerFilterStop,
W.DockerFilterDie, W.DockerFilterDie,
W.DockerFilterKill, W.DockerFilterKill,
W.DockerFilterDestroy,
W.DockerFilterPause, W.DockerFilterPause,
W.DockerFilterUnpause, W.DockerFilterUnpause,
), ),
}) })
return
}
defer func() { // watchUntilDestroy waits for the container to be created, started, or unpaused,
w.cancel() // and then reset the idle timer.
w.ticker.Stop() //
w.client.Close() // When the container is stopped, paused,
watcherMap.Delete(w.ContainerID) // or killed, the idle timer is stopped and the ContainerRunning flag is set to false.
w.task.Finished() //
}() // When the idle timer fires, the container is stopped according to the
// stop method.
//
// it exits only if the context is canceled, the container is destroyed,
// errors occured on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() error {
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
for { for {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
w.l.Debug("stopped by context done") cause := context.Cause(w.task.Context())
return w.l.Debugf("watcher stopped by context done: %s", cause)
case <-w.refCount.Zero(): return cause
w.l.Debug("stopped by zero ref count")
return
case err := <-dockerEventErrCh: case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) { if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err)) w.l.Error(E.FailWith("docker watcher", err))
return return err.Error()
} }
case e := <-dockerEventCh: case e := <-dockerEventCh:
switch { switch {
case e.Action == events.ActionContainerDestroy:
w.ContainerRunning = false
w.ready.Store(false)
w.l.Info("watcher stopped by container destruction")
return errors.New("container destroyed")
// create / start / unpause // create / start / unpause
case e.Action.IsContainerWake(): case e.Action.IsContainerWake():
w.ContainerRunning = true w.ContainerRunning = true
@ -229,18 +244,31 @@ func (w *Watcher) watchUntilCancel() {
w.l.Info("container awaken") w.l.Info("container awaken")
case e.Action.IsContainerSleep(): // stop / pause / kil case e.Action.IsContainerSleep(): // stop / pause / kil
w.ContainerRunning = false w.ContainerRunning = false
w.ticker.Stop()
w.ready.Store(false) w.ready.Store(false)
w.ticker.Stop()
default: default:
w.l.Errorf("unexpected docker event: %s", e) w.l.Errorf("unexpected docker event: %s", e)
} }
// container name changed should also change the container id
if w.ContainerName != e.ActorName {
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName)
w.ContainerName = e.ActorName
}
if w.ContainerID != e.ActorID {
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID)
w.ContainerID = e.ActorID
// recreate event stream
eventTask.Finish("recreate event stream")
eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher)
}
case <-w.ticker.C: case <-w.ticker.C:
w.l.Debug("idle timeout")
w.ticker.Stop() w.ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) { if w.ContainerRunning {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod)) if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
} else { w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err)
w.l.Info("stopped by idle timeout") } else {
w.l.Info("container stopped by idle timeout")
}
} }
} }
} }

View file

@ -2,6 +2,7 @@ package docker
import ( import (
"context" "context"
"errors"
"time" "time"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -19,7 +20,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
} }
func (c Client) Inspect(containerID string) (*Container, E.NestedError) { func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel() defer cancel()
json, err := c.ContainerInspect(ctx, containerID) json, err := c.ContainerInspect(ctx, containerID)

View file

@ -17,35 +17,36 @@ type builder struct {
} }
func NewBuilder(format string, args ...any) Builder { func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}} if len(args) > 0 {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
}
return Builder{&builder{message: format}}
} }
// adding nil / nil is no-op, // adding nil / nil is no-op,
// you may safely pass expressions returning error to it. // you may safely pass expressions returning error to it.
func (b Builder) Add(err NestedError) Builder { func (b Builder) Add(err NestedError) {
if err != nil { if err != nil {
b.Lock() b.Lock()
b.errors = append(b.errors, err) b.errors = append(b.errors, err)
b.Unlock() b.Unlock()
} }
return b
} }
func (b Builder) AddE(err error) Builder { func (b Builder) AddE(err error) {
return b.Add(From(err)) b.Add(From(err))
} }
func (b Builder) Addf(format string, args ...any) Builder { func (b Builder) Addf(format string, args ...any) {
return b.Add(errorf(format, args...)) b.Add(errorf(format, args...))
} }
func (b Builder) AddRangeE(errs ...error) Builder { func (b Builder) AddRangeE(errs ...error) {
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
for _, err := range errs { for _, err := range errs {
b.AddE(err) b.AddE(err)
} }
return b
} }
// Build builds a NestedError based on the errors collected in the Builder. // Build builds a NestedError based on the errors collected in the Builder.

View file

@ -2,6 +2,7 @@ package error
import ( import (
stderrors "errors" stderrors "errors"
"fmt"
"reflect" "reflect"
) )
@ -16,6 +17,7 @@ var (
ErrOutOfRange = stderrors.New("out of range") ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error") ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch") ErrTypeMismatch = stderrors.New("type mismatch")
ErrPanicRecv = stderrors.New("panic")
) )
const fmtSubjectWhat = "%w %v: %q" const fmtSubjectWhat = "%w %v: %q"
@ -75,3 +77,7 @@ func TypeError2(subject any, from, to reflect.Value) NestedError {
func TypeMismatch[Expect any](value any) NestedError { func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
} }
func PanicRecv(format string, args ...any) NestedError {
return errorf("%w%s", ErrPanicRecv, fmt.Sprintf(format, args...))
}

View file

@ -4,18 +4,20 @@ import (
"hash/fnv" "hash/fnv"
"net" "net"
"net/http" "net/http"
"sync"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
) )
type ipHash struct { type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware realIP *middleware.Middleware
pool servers
mu sync.Mutex
} }
func (lb *LoadBalancer) newIPHash() impl { func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb} impl := new(ipHash)
if len(lb.Options) == 0 { if len(lb.Options) == 0 {
return impl return impl
} }
@ -26,10 +28,37 @@ func (lb *LoadBalancer) newIPHash() impl {
} }
return impl return impl
} }
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) { func (impl *ipHash) OnAddServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
return
}
if s == nil {
impl.pool[i] = srv
return
}
}
impl.pool = append(impl.pool, srv)
}
func (impl *ipHash) OnRemoveServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
impl.pool[i] = nil
return
}
}
}
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil { if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r) impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else { } else {
@ -37,7 +66,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request)
} }
} }
func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr) ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
@ -45,10 +74,12 @@ func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
return return
} }
idx := hashIP(ip) % uint32(len(impl.pool)) idx := hashIP(ip) % uint32(len(impl.pool))
if impl.pool[idx].Status().Bad() {
srv := impl.pool[idx]
if srv == nil || srv.Status().Bad() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
} }
impl.pool[idx].ServeHTTP(rw, r) srv.ServeHTTP(rw, r)
} }
func hashIP(ip string) uint32 { func hashIP(ip string) uint32 {

View file

@ -5,8 +5,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/go-acme/lego/v4/log" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -28,7 +29,9 @@ type (
impl impl
*Config *Config
pool servers task task.Task
pool Pool
poolMu sync.Mutex poolMu sync.Mutex
sumWeight weightType sumWeight weightType
@ -41,11 +44,35 @@ type (
const maxWeight weightType = 100 const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer { func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)} lb := &LoadBalancer{
Config: new(Config),
pool: newPool(),
task: task.DummyTask(),
}
lb.UpdateConfigIfNeeded(cfg) lb.UpdateConfigIfNeeded(cfg)
return lb return lb
} }
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
lb.startTime = time.Now()
lb.task = routeSubtask
lb.task.OnComplete("loadbalancer cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v)
})
}
lb.pool.Clear()
})
return nil
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason string) {
lb.task.Finish(reason)
}
func (lb *LoadBalancer) updateImpl() { func (lb *LoadBalancer) updateImpl() {
switch lb.Mode { switch lb.Mode {
case Unset, RoundRobin: case Unset, RoundRobin:
@ -57,9 +84,9 @@ func (lb *LoadBalancer) updateImpl() {
default: // should happen in test only default: // should happen in test only
lb.impl = lb.newRoundRobin() lb.impl = lb.newRoundRobin()
} }
for _, srv := range lb.pool { lb.pool.RangeAll(func(_ string, srv *Server) {
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
} })
} }
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
@ -91,55 +118,60 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
lb.pool = append(lb.pool, srv) if lb.pool.Has(srv.Name) {
old, _ := lb.pool.Load(srv.Name)
lb.sumWeight -= old.Weight
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name, srv)
lb.sumWeight += srv.Weight lb.sumWeight += srv.Weight
lb.Rebalance() lb.rebalance()
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool)) logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
} }
func (lb *LoadBalancer) RemoveServer(srv *Server) { func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
lb.sumWeight -= srv.Weight if !lb.pool.Has(srv.Name) {
lb.Rebalance()
lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool {
if s == srv {
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
break
}
}
if lb.IsEmpty() {
lb.pool = nil
return return
} }
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool)) lb.pool.Delete(srv.Name)
lb.sumWeight -= srv.Weight
lb.rebalance()
lb.impl.OnRemoveServer(srv)
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
logger.Infof("[remove] loadbalancer %s stopped", lb.Link)
return
}
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
} }
func (lb *LoadBalancer) IsEmpty() bool { func (lb *LoadBalancer) rebalance() {
return len(lb.pool) == 0
}
func (lb *LoadBalancer) Rebalance() {
if lb.sumWeight == maxWeight { if lb.sumWeight == maxWeight {
return return
} }
if lb.pool.Size() == 0 {
return
}
if lb.sumWeight == 0 { // distribute evenly if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(len(lb.pool)) weightEach := maxWeight / weightType(lb.pool.Size())
remainder := maxWeight % weightType(len(lb.pool)) remainder := maxWeight % weightType(lb.pool.Size())
for _, s := range lb.pool { lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightEach s.Weight = weightEach
lb.sumWeight += weightEach lb.sumWeight += weightEach
if remainder > 0 { if remainder > 0 {
s.Weight++ s.Weight++
remainder-- remainder--
} }
} })
return return
} }
@ -147,18 +179,18 @@ func (lb *LoadBalancer) Rebalance() {
scaleFactor := float64(maxWeight) / float64(lb.sumWeight) scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0 lb.sumWeight = 0
for _, s := range lb.pool { lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightType(float64(s.Weight) * scaleFactor) s.Weight = weightType(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight lb.sumWeight += s.Weight
} })
delta := maxWeight - lb.sumWeight delta := maxWeight - lb.sumWeight
if delta == 0 { if delta == 0 {
return return
} }
for _, s := range lb.pool { lb.pool.Range(func(_ string, s *Server) bool {
if delta == 0 { if delta == 0 {
break return false
} }
if delta > 0 { if delta > 0 {
s.Weight++ s.Weight++
@ -169,7 +201,8 @@ func (lb *LoadBalancer) Rebalance() {
lb.sumWeight-- lb.sumWeight--
delta++ delta++
} }
} return true
})
} }
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
@ -181,23 +214,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
lb.impl.ServeHTTP(srvs, rw, r) lb.impl.ServeHTTP(srvs, rw, r)
} }
func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
lb.startTime = time.Now()
logger.Debugf("loadbalancer %s started", lb.Link)
}
func (lb *LoadBalancer) Stop() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
}
func (lb *LoadBalancer) Uptime() time.Duration { func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime) return time.Since(lb.startTime)
} }
@ -205,9 +221,10 @@ func (lb *LoadBalancer) Uptime() time.Duration {
// MarshalJSON implements health.HealthMonitor. // MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) { func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any) extra := make(map[string]any)
for _, v := range lb.pool { lb.pool.RangeAll(func(k string, v *Server) {
extra[v.Name] = v.healthMon extra[v.Name] = v.healthMon
} })
return (&health.JSONRepresentation{ return (&health.JSONRepresentation{
Name: lb.Name(), Name: lb.Name(),
Status: lb.Status(), Status: lb.Status(),
@ -227,7 +244,7 @@ func (lb *LoadBalancer) Name() string {
// Status implements health.HealthMonitor. // Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status { func (lb *LoadBalancer) Status() health.Status {
if len(lb.pool) == 0 { if lb.pool.Size() == 0 {
return health.StatusUnknown return health.StatusUnknown
} }
if len(lb.availServers()) == 0 { if len(lb.availServers()) == 0 {
@ -241,21 +258,13 @@ func (lb *LoadBalancer) String() string {
return lb.Name() return lb.Name()
} }
func (lb *LoadBalancer) availServers() servers { func (lb *LoadBalancer) availServers() []*Server {
lb.poolMu.Lock() avail := make([]*Server, 0, lb.pool.Size())
defer lb.poolMu.Unlock() lb.pool.RangeAll(func(_ string, srv *Server) {
if srv.Status().Bad() {
avail := make(servers, 0, len(lb.pool)) return
for _, s := range lb.pool {
if s.Status().Bad() {
continue
} }
avail = append(avail, s) avail = append(avail, srv)
} })
return avail return avail
} }
// static HealthMonitor interface check
func (lb *LoadBalancer) _() health.HealthMonitor {
return lb
}

View file

@ -13,7 +13,7 @@ func TestRebalance(t *testing.T) {
for range 10 { for range 10 {
lb.AddServer(&Server{}) lb.AddServer(&Server{})
} }
lb.Rebalance() lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)
}) })
t.Run("less", func(t *testing.T) { t.Run("less", func(t *testing.T) {
@ -23,7 +23,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance() lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)
}) })
@ -36,7 +36,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)}) lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance() lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " "))) // t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight) ExpectEqual(t, lb.sumWeight, maxWeight)
}) })

View file

@ -6,6 +6,7 @@ import (
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -20,9 +21,12 @@ type (
handler http.Handler handler http.Handler
healthMon health.HealthMonitor healthMon health.HealthMonitor
} }
servers []*Server servers = []*Server
Pool = F.Map[string, *Server]
) )
var newPool = F.NewMap[Pool]
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server { func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{ srv := &Server{
Name: name, Name: name,

View file

@ -48,11 +48,11 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
} }
delete(def, "use") delete(def, "use")
m, err := base.WithOptionsClone(def) m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil { if err != nil {
chainErr.Add(err.Subjectf("item%d", i)) chainErr.Add(err.Subjectf("item%d", i))
continue continue
} }
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m) chain = append(chain, m)
} }
if chainErr.HasError() { if chainErr.HasError() {

View file

@ -0,0 +1,19 @@
package types
import (
"fmt"
"net"
)
type Stream interface {
fmt.Stringer
Setup() error
Accept() (conn StreamConn, err error)
Handle(conn StreamConn) error
CloseListeners()
}
type StreamConn interface {
RemoteAddr() net.Addr
Close() error
}

View file

@ -1,177 +0,0 @@
package proxy
import (
"fmt"
"net/url"
"time"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
ReverseProxyEntry struct { // real model after validation
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.Scheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns T.PathPatterns `json:"path_patterns,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty"`
/* Docker only */
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopMethod T.StopMethod `json:"stop_method,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"`
StopSignal T.Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StreamEntry struct {
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.StreamScheme `json:"scheme,omitempty"`
Host T.Host `json:"host,omitempty"`
Port T.StreamPort `json:"port,omitempty"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
}
)
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.IsDocker()
}
func (rp *ReverseProxyEntry) UseLoadBalance() bool {
return rp.LoadBalance != nil && rp.LoadBalance.Link != ""
}
func (rp *ReverseProxyEntry) IsDocker() bool {
return rp.DockerHost != ""
}
func (rp *ReverseProxyEntry) IsZeroPort() bool {
return rp.URL.Port() == "0"
}
func (rp *ReverseProxyEntry) ShouldNotServe() bool {
return rp.IsZeroPort() && !rp.UseIdleWatcher()
}
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme)
if err != nil {
return nil, err
}
var entry any
e := E.NewBuilder("error validating entry")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err != nil {
return nil, err
}
return entry, nil
}
func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
cont := m.Container
if cont == nil {
cont = D.DummyContainer
}
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(cont.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(cont.StopSignal)
b.Add(err)
if err != nil {
return nil
}
return &ReverseProxyEntry{
Raw: m,
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
HealthCheck: &m.HealthCheck,
LoadBalance: &m.LoadBalance,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}
}
// validateStreamEntry validates the raw entry as a tcp/udp stream entry.
// Field errors are collected into b; nil is returned if any check failed.
func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
	host, hostErr := T.ValidateHost(m.Host)
	b.Add(hostErr)

	port, portErr := T.ValidateStreamPort(m.Port)
	b.Add(portErr)

	scheme, schemeErr := T.ValidateStreamScheme(m.Scheme)
	b.Add(schemeErr)

	if b.HasError() {
		return nil
	}

	return &StreamEntry{
		Raw:         m,
		Alias:       T.NewAlias(m.Alias),
		Scheme:      *scheme,
		Host:        host,
		Port:        port,
		Healthcheck: &m.HealthCheck,
	}
}

View file

@ -0,0 +1,68 @@
package entry
import (
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Entry is the validated form of a RawEntry. It is implemented by
// *ReverseProxyEntry (http/https) and *StreamEntry (tcp/udp) and exposes
// the configuration shared by both route kinds.
type Entry interface {
	// TargetName returns the entry's alias.
	TargetName() string
	// TargetURL returns the URL proxied traffic is forwarded to.
	TargetURL() net.URL
	// RawEntry returns the pre-validation entry this was built from.
	RawEntry() *RawEntry
	// LoadBalanceConfig returns the load-balancer config; may be nil.
	LoadBalanceConfig() *loadbalancer.Config
	// HealthCheckConfig returns the health-check config; may be nil.
	HealthCheckConfig() *health.HealthCheckConfig
	// IdlewatcherConfig returns the idlewatcher config; may be nil.
	IdlewatcherConfig() *idlewatcher.Config
}
// ValidateEntry fills in missing fields of the raw entry and validates it,
// returning a *StreamEntry for stream schemes or a *ReverseProxyEntry
// otherwise, or the accumulated validation error.
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) {
	m.FillMissingFields()

	scheme, err := T.NewScheme(m.Scheme)
	if err != nil {
		return nil, err
	}

	errs := E.NewBuilder("error validating entry")
	var validated Entry
	if scheme.IsStream() {
		validated = validateStreamEntry(m, errs)
	} else {
		validated = validateRPEntry(m, scheme, errs)
	}

	if buildErr := errs.Build(); buildErr != nil {
		return nil, buildErr
	}
	return validated, nil
}
// IsDocker reports whether the entry belongs to a docker container.
func IsDocker(entry Entry) bool {
	cfg := entry.IdlewatcherConfig()
	if cfg == nil {
		return false
	}
	return cfg.ContainerID != ""
}
// IsZeroPort reports whether the entry's target URL carries the
// placeholder port "0".
func IsZeroPort(entry Entry) bool {
	port := entry.TargetURL().Port()
	return port == "0"
}
// ShouldNotServe reports whether the entry has no usable port and
// idlewatcher cannot wake it either.
func ShouldNotServe(entry Entry) bool {
	if !IsZeroPort(entry) {
		return false
	}
	return !UseIdleWatcher(entry)
}
// UseLoadBalance reports whether the entry is part of a load-balanced group.
func UseLoadBalance(entry Entry) bool {
	cfg := entry.LoadBalanceConfig()
	if cfg == nil {
		return false
	}
	return cfg.Link != ""
}
// UseIdleWatcher reports whether idlewatcher is enabled for the entry.
func UseIdleWatcher(entry Entry) bool {
	cfg := entry.IdlewatcherConfig()
	if cfg == nil {
		return false
	}
	return cfg.IdleTimeout > 0
}
// UseHealthCheck reports whether health checking is enabled for the entry.
func UseHealthCheck(entry Entry) bool {
	cfg := entry.HealthCheckConfig()
	if cfg == nil {
		return false
	}
	return !cfg.Disabled
}

View file

@ -1,4 +1,4 @@
package types package entry
import ( import (
"strconv" "strconv"
@ -21,16 +21,16 @@ type (
// raw entry object before validation // raw entry object before validation
// loaded from docker labels or yaml file // loaded from docker labels or yaml file
Alias string `json:"-" yaml:"-"` Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme,omitempty" yaml:"scheme"` Scheme string `json:"scheme,omitempty" yaml:"scheme"`
Host string `json:"host,omitempty" yaml:"host"` Host string `json:"host,omitempty" yaml:"host"`
Port string `json:"port,omitempty" yaml:"port"` Port string `json:"port,omitempty" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"` LoadBalance *loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */ /* Docker only */
Container *docker.Container `json:"container,omitempty" yaml:"-"` Container *docker.Container `json:"container,omitempty" yaml:"-"`
@ -122,29 +122,41 @@ func (e *RawEntry) FillMissingFields() {
} }
} }
if e.HealthCheck.Interval == 0 { if e.HealthCheck == nil {
e.HealthCheck.Interval = common.HealthCheckIntervalDefault e.HealthCheck = new(health.HealthCheckConfig)
} }
if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault if e.HealthCheck.Disabled {
e.HealthCheck = nil
} else {
if e.HealthCheck.Interval == 0 {
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
}
if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
}
} }
if cont.IdleTimeout == "" {
cont.IdleTimeout = common.IdleTimeoutDefault if cont.IdleTimeout != "" {
} if cont.WakeTimeout == "" {
if cont.WakeTimeout == "" { cont.WakeTimeout = common.WakeTimeoutDefault
cont.WakeTimeout = common.WakeTimeoutDefault }
} if cont.StopTimeout == "" {
if cont.StopTimeout == "" { cont.StopTimeout = common.StopTimeoutDefault
cont.StopTimeout = common.StopTimeoutDefault }
} if cont.StopMethod == "" {
if cont.StopMethod == "" { cont.StopMethod = common.StopMethodDefault
cont.StopMethod = common.StopMethodDefault }
} }
e.Port = joinPorts(lp, pp, extra) e.Port = joinPorts(lp, pp, extra)
if e.Port == "" || e.Host == "" { if e.Port == "" || e.Host == "" {
e.Port = "0" if lp != "" {
e.Port = lp + ":0"
} else {
e.Port = "0"
}
} }
} }

View file

@ -0,0 +1,98 @@
package entry
import (
"fmt"
"net/url"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// ReverseProxyEntry is the validated model of a http/https RawEntry.
// It implements the Entry interface.
type ReverseProxyEntry struct { // real model after validation
	// Raw is the pre-validation entry this was built from.
	Raw *RawEntry `json:"raw"`

	Alias        fields.Alias              `json:"alias,omitempty"`
	Scheme       fields.Scheme             `json:"scheme,omitempty"`
	URL          net.URL                   `json:"url,omitempty"`
	NoTLSVerify  bool                      `json:"no_tls_verify,omitempty"`
	PathPatterns fields.PathPatterns       `json:"path_patterns,omitempty"`
	HealthCheck  *health.HealthCheckConfig `json:"healthcheck,omitempty"`
	LoadBalance  *loadbalancer.Config      `json:"load_balance,omitempty"`
	Middlewares  docker.NestedLabelMap     `json:"middlewares,omitempty"`

	/* Docker only */
	Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
}
// TargetName implements Entry; it returns the entry's alias.
func (rp *ReverseProxyEntry) TargetName() string {
	alias := rp.Alias
	return string(alias)
}
// TargetURL implements Entry; it returns the proxied-to URL.
func (rp *ReverseProxyEntry) TargetURL() net.URL {
	u := rp.URL
	return u
}
// RawEntry implements Entry; it returns the pre-validation entry.
func (rp *ReverseProxyEntry) RawEntry() *RawEntry {
	raw := rp.Raw
	return raw
}
// LoadBalanceConfig implements Entry; may be nil when load balancing
// is not configured.
func (rp *ReverseProxyEntry) LoadBalanceConfig() *loadbalancer.Config {
	cfg := rp.LoadBalance
	return cfg
}
// HealthCheckConfig implements Entry; may be nil.
func (rp *ReverseProxyEntry) HealthCheckConfig() *health.HealthCheckConfig {
	cfg := rp.HealthCheck
	return cfg
}
// IdlewatcherConfig implements Entry; may be nil when idlewatcher
// is not configured.
func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
	cfg := rp.Idlewatcher
	return cfg
}
// validateRPEntry validates the raw entry as a reverse-proxy (http/https)
// entry. Every field error is collected into b; nil is returned when any
// validation failed, so the caller can report the complete error list.
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry {
	// a load-balance config without a link is treated as absent
	lb := m.LoadBalance
	if lb != nil && lb.Link == "" {
		lb = nil
	}

	host, err := fields.ValidateHost(m.Host)
	b.Add(err)

	port, err := fields.ValidatePort(m.Port)
	b.Add(err)

	pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
	b.Add(err)

	// named `u` so it does not shadow the net/url package
	u, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
	b.Add(err)

	idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
	b.Add(err)

	// BUGFIX: previously only the last `err` was checked here, so earlier
	// validation failures still produced an entry with zero-value fields.
	// Check the whole builder instead, matching validateStreamEntry.
	// (The dead `cont := m.Container / DummyContainer` block was removed;
	// nothing below read it.)
	if b.HasError() {
		return nil
	}

	return &ReverseProxyEntry{
		Raw:          m,
		Alias:        fields.NewAlias(m.Alias),
		Scheme:       s,
		URL:          net.NewURL(u),
		NoTLSVerify:  m.NoTLSVerify,
		PathPatterns: pathPatterns,
		HealthCheck:  m.HealthCheck,
		LoadBalance:  lb,
		Middlewares:  m.Middlewares,
		Idlewatcher:  idleWatcherCfg,
	}
}

View file

@ -0,0 +1,89 @@
package entry
import (
"fmt"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// StreamEntry is the validated model of a tcp/udp RawEntry.
// It implements the Entry interface.
type StreamEntry struct {
	// Raw is the pre-validation entry this was built from.
	Raw *RawEntry `json:"raw"`

	Alias       fields.Alias              `json:"alias,omitempty"`
	Scheme      fields.StreamScheme       `json:"scheme,omitempty"`
	URL         net.URL                   `json:"url,omitempty"`
	Host        fields.Host               `json:"host,omitempty"`
	Port        fields.StreamPort         `json:"port,omitempty"`
	HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`

	/* Docker only */
	Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
}
// TargetName implements Entry; it returns the entry's alias.
func (s *StreamEntry) TargetName() string {
	alias := s.Alias
	return string(alias)
}
// TargetURL implements Entry; it returns the proxied-to URL.
func (s *StreamEntry) TargetURL() net.URL {
	u := s.URL
	return u
}
// RawEntry implements Entry; it returns the pre-validation entry.
func (s *StreamEntry) RawEntry() *RawEntry {
	raw := s.Raw
	return raw
}
// LoadBalanceConfig implements Entry.
// TODO: support stream load balance; always nil for now.
func (s *StreamEntry) LoadBalanceConfig() *loadbalancer.Config {
	return nil
}
// HealthCheckConfig implements Entry; may be nil.
func (s *StreamEntry) HealthCheckConfig() *health.HealthCheckConfig {
	cfg := s.HealthCheck
	return cfg
}
// IdlewatcherConfig implements Entry; may be nil when idlewatcher
// is not configured.
func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
	cfg := s.Idlewatcher
	return cfg
}
// validateStreamEntry validates the raw entry as a tcp/udp stream entry.
// Field errors are collected into b; nil is returned if any check failed.
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry {
	// NOTE(review): cont is computed but never read below — idlewatcher
	// config is validated from m.Container directly. Kept for parity with
	// the old code; TODO confirm it can be dropped (it is the only thing
	// here keeping the docker import alive besides nothing).
	cont := m.Container
	if cont == nil {
		cont = docker.DummyContainer
	}

	host, err := fields.ValidateHost(m.Host)
	b.Add(err)

	port, err := fields.ValidateStreamPort(m.Port)
	b.Add(err)

	scheme, err := fields.ValidateStreamScheme(m.Scheme)
	b.Add(err)

	idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
	b.Add(err)

	// BUGFIX: bail out BEFORE dereferencing scheme — it is nil when
	// ValidateStreamScheme failed, and scheme.ProxyScheme below would
	// panic. The old code built the URL first and only then checked
	// b.HasError().
	if b.HasError() {
		return nil
	}

	// use the validated host rather than the raw m.Host so URL and Host
	// cannot diverge
	url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, host, port.ProxyPort)))
	if err != nil {
		b.Add(err)
		return nil
	}

	return &StreamEntry{
		Raw:         m,
		Alias:       fields.NewAlias(m.Alias),
		Scheme:      *scheme,
		URL:         url,
		Host:        host,
		Port:        port,
		HealthCheck: m.HealthCheck,
		Idlewatcher: idleWatcherCfg,
	}
}

View file

@ -1,17 +0,0 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
// Signal is a validated process-signal name for the container
// stop_signal option.
type Signal string

// ValidateSignal accepts the empty string or one of the supported
// signal names, with or without the "SIG" prefix.
func ValidateSignal(s string) (Signal, E.NestedError) {
	switch s {
	case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
		"INT", "TERM", "HUP", "QUIT":
		return Signal(s), nil
	default:
		return "", E.Invalid("signal", s)
	}
}

View file

@ -1,23 +0,0 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
// StopMethod selects how an idle container is stopped.
type StopMethod string

const (
	StopMethodPause StopMethod = "pause"
	StopMethodStop  StopMethod = "stop"
	StopMethodKill  StopMethod = "kill"
)

// ValidateStopMethod accepts exactly one of pause, stop or kill.
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
	method := StopMethod(s)
	switch method {
	case StopMethodPause, StopMethodStop, StopMethodKill:
		return method, nil
	}
	return "", E.Invalid("stop_method", method)
}

View file

@ -1,18 +0,0 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
)
// ValidateDurationPostitive parses value as a time.Duration and rejects
// negative results. (The typo in the name is kept: it is the exported,
// caller-facing identifier.)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
	d, parseErr := time.ParseDuration(value)
	switch {
	case parseErr != nil:
		return 0, E.Invalid("duration", value)
	case d < 0:
		return 0, E.Invalid("duration", "negative value")
	default:
		return d, nil
	}
}

View file

@ -9,23 +9,21 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage" "github.com/yusing/go-proxy/internal/api/v1/errorpage"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker/idlewatcher" "github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
url "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/proxy/entry"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields" PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types" "github.com/yusing/go-proxy/internal/task"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
type ( type (
HTTPRoute struct { HTTPRoute struct {
*P.ReverseProxyEntry *entry.ReverseProxyEntry
HealthMon health.HealthMonitor `json:"health,omitempty"` HealthMon health.HealthMonitor `json:"health,omitempty"`
@ -33,6 +31,8 @@ type (
server *loadbalancer.Server server *loadbalancer.Server
handler http.Handler handler http.Handler
rp *gphttp.ReverseProxy rp *gphttp.ReverseProxy
task task.Task
} }
SubdomainKey = PT.Alias SubdomainKey = PT.Alias
@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
} }
} }
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
var trans *http.Transport var trans *http.Transport
if entry.NoTLSVerify { if entry.NoTLSVerify {
@ -84,12 +84,10 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
} }
} }
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
r := &HTTPRoute{ r := &HTTPRoute{
ReverseProxyEntry: entry, ReverseProxyEntry: entry,
rp: rp, rp: rp,
task: task.DummyTask(),
} }
return r, nil return r, nil
} }
@ -98,39 +96,34 @@ func (r *HTTPRoute) String() string {
return string(r.Alias) return string(r.Alias)
} }
func (r *HTTPRoute) URL() url.URL { // Start implements task.TaskStarter.
return r.ReverseProxyEntry.URL func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
} if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
func (r *HTTPRoute) Start() E.NestedError {
if r.ShouldNotServe() {
return nil return nil
} }
httpRoutesMu.Lock() httpRoutesMu.Lock()
defer httpRoutesMu.Unlock() defer httpRoutesMu.Unlock()
if r.handler != nil { if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
return nil
}
if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = true r.HealthCheck.Disabled = true
} }
switch { switch {
case r.UseIdleWatcher(): case entry.UseIdleWatcher(r):
watcher, err := idlewatcher.Register(r.ReverseProxyEntry) wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
if err != nil { if err != nil {
return err return err
} }
waker := idlewatcher.NewWaker(watcher, r.rp)
r.handler = waker r.handler = waker
r.HealthMon = waker r.HealthMon = waker
case !r.HealthCheck.Disabled: case entry.UseHealthCheck(r):
r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask(r.String()), r.URL(), r.HealthCheck) r.HealthMon = health.NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheck, r.rp.Transport)
} }
r.task = providerSubtask
if r.handler == nil { if r.handler == nil {
switch { switch {
@ -146,44 +139,26 @@ func (r *HTTPRoute) Start() E.NestedError {
} }
if r.HealthMon != nil { if r.HealthMon != nil {
r.HealthMon.Start() if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn(E.FailWith("health monitor", err))
}
} }
if r.UseLoadBalance() { if entry.UseLoadBalance(r) {
r.addToLoadBalancer() r.addToLoadBalancer()
} else { } else {
httpRoutes.Store(string(r.Alias), r) httpRoutes.Store(string(r.Alias), r)
r.task.OnComplete("stop rp", func() {
httpRoutes.Delete(string(r.Alias))
})
} }
return nil return nil
} }
func (r *HTTPRoute) Stop() (_ E.NestedError) { // Finish implements task.TaskFinisher.
if r.handler == nil { func (r *HTTPRoute) Finish(reason string) {
return r.task.Finish(reason)
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.loadBalancer != nil {
r.removeFromLoadBalancer()
} else {
httpRoutes.Delete(string(r.Alias))
}
if r.HealthMon != nil {
r.HealthMon.Stop()
r.HealthMon = nil
}
r.handler = nil
return
}
func (r *HTTPRoute) Started() bool {
return r.handler != nil
} }
func (r *HTTPRoute) addToLoadBalancer() { func (r *HTTPRoute) addToLoadBalancer() {
@ -197,10 +172,14 @@ func (r *HTTPRoute) addToLoadBalancer() {
} }
} else { } else {
lb = loadbalancer.New(r.LoadBalance) lb = loadbalancer.New(r.LoadBalance)
lb.Start() lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
lbTask.OnComplete("remove lb from routes", func() {
httpRoutes.Delete(r.LoadBalance.Link)
})
lb.Start(lbTask)
linked = &HTTPRoute{ linked = &HTTPRoute{
ReverseProxyEntry: &P.ReverseProxyEntry{ ReverseProxyEntry: &entry.ReverseProxyEntry{
Raw: &types.RawEntry{ Raw: &entry.RawEntry{
Homepage: r.Raw.Homepage, Homepage: r.Raw.Homepage,
}, },
Alias: PT.Alias(lb.Link), Alias: PT.Alias(lb.Link),
@ -214,16 +193,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
r.loadBalancer = lb r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server) lb.AddServer(r.server)
} r.task.OnComplete("remove server from lb", func() {
lb.RemoveServer(r.server)
func (r *HTTPRoute) removeFromLoadBalancer() { })
r.loadBalancer.RemoveServer(r.server)
if r.loadBalancer.IsEmpty() {
httpRoutes.Delete(r.LoadBalance.Link)
logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link)
}
r.server = nil
r.loadBalancer = nil
} }
func ProxyHandler(w http.ResponseWriter, r *http.Request) { func ProxyHandler(w http.ResponseWriter, r *http.Request) {

View file

@ -10,10 +10,9 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
) )
type DockerProvider struct { type DockerProvider struct {
@ -43,7 +42,7 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
routes = R.NewRoutes() routes = R.NewRoutes()
entries := types.NewProxyEntries() entries := entry.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true) info, err := D.GetClientInfo(p.dockerHost, true)
if err != nil { if err != nil {
@ -66,12 +65,12 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
// there may be some valid entries in `en` // there may be some valid entries in `en`
dups := entries.MergeFrom(newEntries) dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error // add the duplicate proxy entries to the error
dups.RangeAll(func(k string, v *types.RawEntry) { dups.RangeAll(func(k string, v *entry.RawEntry) {
errors.Addf("duplicate alias %s", k) errors.Addf("duplicate alias %s", k)
}) })
} }
entries.RangeAll(func(_ string, e *types.RawEntry) { entries.RangeAll(func(_ string, e *entry.RawEntry) {
e.Container.DockerHost = p.dockerHost e.Container.DockerHost = p.dockerHost
}) })
@ -88,85 +87,10 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
strings.HasSuffix(container.ContainerName, "-old") strings.HasSuffix(container.ContainerName, "-old")
} }
func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
matches := R.NewRoutes()
oldRoutes.RangeAllParallel(func(k string, v *R.Route) {
if v.Entry.Container.ContainerID == event.ActorID ||
v.Entry.Container.ContainerName == event.ActorName {
matches.Store(k, v)
}
})
//FIXME: docker event die stuck
var newRoutes R.Routes
var err E.NestedError
switch {
// id & container name changed
case matches.Size() == 0:
matches = oldRoutes
newRoutes, err = p.LoadRoutesImpl()
b.Add(err)
case event.Action == events.ActionContainerDestroy:
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
oldRoutes.Delete(v.Entry.Alias)
b.Add(v.Stop())
res.nRemoved++
})
return
default:
cont, err := D.Inspect(p.dockerHost, event.ActorID)
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
if p.shouldIgnore(cont) {
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
res.nRemoved++
})
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
newRoutes, err = R.FromEntries(entries)
b.Add(err)
}
matches.RangeAll(func(k string, v *R.Route) {
if !newRoutes.Has(k) && !oldRoutes.Has(k) {
b.Add(v.Stop())
matches.Delete(k)
res.nRemoved++
}
})
newRoutes.RangeAll(func(alias string, newRoute *R.Route) {
oldRoute, exists := oldRoutes.Load(alias)
if exists {
b.Add(oldRoute.Stop())
res.nReloaded++
} else {
res.nAdded++
}
b.Add(newRoute.Start())
oldRoutes.Store(alias, newRoute)
})
return
}
// Returns a list of proxy entries for a container. // Returns a list of proxy entries for a container.
// Always non-nil. // Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries types.RawEntries, _ E.NestedError) { func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
entries = types.NewProxyEntries() entries = entry.NewProxyEntries()
if p.shouldIgnore(container) { if p.shouldIgnore(container) {
return return
@ -174,7 +98,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
// init entries map for all aliases // init entries map for all aliases
for _, a := range container.Aliases { for _, a := range container.Aliases {
entries.Store(a, &types.RawEntry{ entries.Store(a, &entry.RawEntry{
Alias: a, Alias: a,
Container: container, Container: container,
}) })
@ -186,14 +110,14 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
} }
// remove all entries that failed to fill in missing fields // remove all entries that failed to fill in missing fields
entries.RangeAll(func(_ string, re *types.RawEntry) { entries.RangeAll(func(_ string, re *entry.RawEntry) {
re.FillMissingFields() re.FillMissingFields()
}) })
return entries, errors.Build().Subject(container.ContainerName) return entries, errors.Build().Subject(container.ContainerName)
} }
func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEntries, key, val string) (res E.NestedError) { func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key) b := E.NewBuilder("errors in label %s", key)
defer b.To(&res) defer b.To(&res)
@ -220,7 +144,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
} }
if lbl.Target == D.WildcardAlias { if lbl.Target == D.WildcardAlias {
// apply label for all aliases // apply label for all aliases
entries.RangeAll(func(a string, e *types.RawEntry) { entries.RangeAll(func(a string, e *entry.RawEntry) {
if err = D.ApplyLabel(e, lbl); err != nil { if err = D.ApplyLabel(e, lbl); err != nil {
b.Add(err) b.Add(err)
} }

View file

@ -10,7 +10,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy" "github.com/yusing/go-proxy/internal/proxy/entry"
T "github.com/yusing/go-proxy/internal/proxy/fields" T "github.com/yusing/go-proxy/internal/proxy/fields"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -46,7 +46,7 @@ func TestApplyLabelWildcard(t *testing.T) {
Names: dummyNames, Names: dummyNames,
Labels: map[string]string{ Labels: map[string]string{
D.LabelAliases: "a,b", D.LabelAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault, D.LabelIdleTimeout: "",
D.LabelStopMethod: common.StopMethodDefault, D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM", D.LabelStopSignal: "SIGTERM",
D.LabelStopTimeout: common.StopTimeoutDefault, D.LabelStopTimeout: common.StopTimeoutDefault,
@ -62,7 +62,7 @@ func TestApplyLabelWildcard(t *testing.T) {
"proxy.a.middlewares.middleware2.prop3": "value3", "proxy.a.middlewares.middleware2.prop3": "value3",
"proxy.a.middlewares.middleware2.prop4": "value4", "proxy.a.middlewares.middleware2.prop4": "value4",
}, },
}, "")) }, client.DefaultDockerHost))
ExpectNoError(t, err.Error()) ExpectNoError(t, err.Error())
a, ok := entries.Load("a") a, ok := entries.Load("a")
@ -88,8 +88,8 @@ func TestApplyLabelWildcard(t *testing.T) {
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect) ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0) ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault) ExpectEqual(t, a.Container.IdleTimeout, "")
ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault) ExpectEqual(t, b.Container.IdleTimeout, "")
ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault) ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault) ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault)
@ -107,6 +107,7 @@ func TestApplyLabelWildcard(t *testing.T) {
func TestApplyLabelWithAlias(t *testing.T) { func TestApplyLabelWithAlias(t *testing.T) {
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames, Names: dummyNames,
State: "running",
Labels: map[string]string{ Labels: map[string]string{
D.LabelAliases: "a,b,c", D.LabelAliases: "a,b,c",
"proxy.a.no_tls_verify": "true", "proxy.a.no_tls_verify": "true",
@ -114,7 +115,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
"proxy.b.port": "1234", "proxy.b.port": "1234",
"proxy.c.scheme": "https", "proxy.c.scheme": "https",
}, },
}, "")) }, client.DefaultDockerHost))
a, ok := entries.Load("a") a, ok := entries.Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
b, ok := entries.Load("b") b, ok := entries.Load("b")
@ -134,6 +135,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
func TestApplyLabelWithRef(t *testing.T) { func TestApplyLabelWithRef(t *testing.T) {
entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames, Names: dummyNames,
State: "running",
Labels: map[string]string{ Labels: map[string]string{
D.LabelAliases: "a,b,c", D.LabelAliases: "a,b,c",
"proxy.#1.host": "localhost", "proxy.#1.host": "localhost",
@ -142,7 +144,7 @@ func TestApplyLabelWithRef(t *testing.T) {
"proxy.#3.port": "1111", "proxy.#3.port": "1111",
"proxy.#3.scheme": "https", "proxy.#3.scheme": "https",
}, },
}, ""))) }, client.DefaultDockerHost)))
a, ok := entries.Load("a") a, ok := entries.Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
b, ok := entries.Load("b") b, ok := entries.Load("b")
@ -161,6 +163,7 @@ func TestApplyLabelWithRef(t *testing.T) {
func TestApplyLabelWithRefIndexError(t *testing.T) { func TestApplyLabelWithRefIndexError(t *testing.T) {
c := D.FromDocker(&types.Container{ c := D.FromDocker(&types.Container{
Names: dummyNames, Names: dummyNames,
State: "running",
Labels: map[string]string{ Labels: map[string]string{
D.LabelAliases: "a,b", D.LabelAliases: "a,b",
"proxy.#1.host": "localhost", "proxy.#1.host": "localhost",
@ -173,6 +176,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{ _, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames, Names: dummyNames,
State: "running",
Labels: map[string]string{ Labels: map[string]string{
D.LabelAliases: "a,b", D.LabelAliases: "a,b",
"proxy.#0.host": "localhost", "proxy.#0.host": "localhost",
@ -183,7 +187,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
} }
func TestPublicIPLocalhost(t *testing.T) { func TestPublicIPLocalhost(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost) c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1") ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1")
@ -191,7 +195,7 @@ func TestPublicIPLocalhost(t *testing.T) {
} }
func TestPublicIPRemote(t *testing.T) { func TestPublicIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375") c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
@ -218,6 +222,7 @@ func TestPrivateIPLocalhost(t *testing.T) {
func TestPrivateIPRemote(t *testing.T) { func TestPrivateIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{ c := D.FromDocker(&types.Container{
Names: dummyNames, Names: dummyNames,
State: "running",
NetworkSettings: &types.SummaryNetworkSettings{ NetworkSettings: &types.SummaryNetworkSettings{
Networks: map[string]*network.EndpointSettings{ Networks: map[string]*network.EndpointSettings{
"network": { "network": {
@ -239,6 +244,7 @@ func TestStreamDefaultValues(t *testing.T) {
privIP := "172.17.0.123" privIP := "172.17.0.123"
cont := &types.Container{ cont := &types.Container{
Names: []string{"a"}, Names: []string{"a"},
State: "running",
NetworkSettings: &types.SummaryNetworkSettings{ NetworkSettings: &types.SummaryNetworkSettings{
Networks: map[string]*network.EndpointSettings{ Networks: map[string]*network.EndpointSettings{
"network": { "network": {
@ -256,9 +262,8 @@ func TestStreamDefaultValues(t *testing.T) {
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
entry := Must(P.ValidateEntry(raw)) en := Must(entry.ValidateEntry(raw))
a := ExpectType[*entry.StreamEntry](t, en)
a := ExpectType[*P.StreamEntry](t, entry)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Host, T.Host(privIP)) ExpectEqual(t, a.Host, T.Host(privIP))
@ -270,9 +275,8 @@ func TestStreamDefaultValues(t *testing.T) {
c := D.FromDocker(cont, "tcp://1.2.3.4:2375") c := D.FromDocker(cont, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
entry := Must(P.ValidateEntry(raw)) en := Must(entry.ValidateEntry(raw))
a := ExpectType[*entry.StreamEntry](t, en)
a := ExpectType[*P.StreamEntry](t, entry)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Host, "1.2.3.4") ExpectEqual(t, a.Host, "1.2.3.4")

View file

@ -0,0 +1,109 @@
package provider
import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher"
)
// EventHandler accumulates the outcome of applying a batch of watcher
// events to a provider's routes, so the result can be reported in one
// log entry via Log.
type EventHandler struct {
	provider *Provider // the provider whose routes are being reconciled

	added   []string // aliases of routes started by Add
	removed []string // aliases of routes stopped by Remove
	paused  []string // NOTE(review): never appended to in this file — confirm intended use
	updated []string // aliases of routes restarted by Update

	errs E.Builder // collects errors from load/start failures
}
// newEventHandler returns a fresh EventHandler bound to this provider,
// with an empty error builder titled "event errors".
func (provider *Provider) newEventHandler() (h *EventHandler) {
	h = &EventHandler{
		provider: provider,
		errs:     E.NewBuilder("event errors"),
	}
	return
}
// Handle reconciles the provider's current routes against a freshly
// loaded snapshot: routes that disappeared are removed, new routes are
// added, and existing routes matching one of the events are restarted
// with their new definition. Errors are collected in handler.errs and
// surfaced later via Log.
func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
	oldRoutes := handler.provider.routes
	newRoutes, err := handler.provider.LoadRoutesImpl()
	if err != nil {
		handler.errs.Add(err.Subject("load routes"))
		return
	}

	// drop routes that no longer exist in the new snapshot
	oldRoutes.RangeAll(func(k string, v *route.Route) {
		if !newRoutes.Has(k) {
			handler.Remove(v)
		}
	})

	newRoutes.RangeAll(func(k string, newr *route.Route) {
		// single Load instead of the previous Has+Load pair: one map
		// lookup, and no impossible "race condition" panic branch
		old, ok := oldRoutes.Load(k)
		if !ok {
			handler.Add(parent, newr)
			return
		}
		for _, ev := range events {
			if handler.match(ev, newr) {
				handler.Update(parent, old, newr)
				return
			}
		}
	})
}
// match reports whether the event concerns the given route. Docker
// providers match the event actor against the route's container by ID
// or name; file providers treat every event as relevant.
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
	switch handler.provider.t {
	case ProviderTypeDocker:
		cont := route.Entry.Container
		return event.ActorID == cont.ContainerID ||
			event.ActorName == cont.ContainerName
	case ProviderTypeFile:
		return true
	default:
		// unknown provider type; should never happen
		return false
	}
}
// Add starts the route under parent, recording its alias on success or
// collecting the error on failure.
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
	if err := handler.provider.startRoute(parent, route); err != nil {
		handler.errs.Add(err)
		return
	}
	handler.added = append(handler.added, route.Entry.Alias)
}
// Remove finishes the route's task and records its alias as removed.
func (handler *EventHandler) Remove(route *route.Route) {
	handler.removed = append(handler.removed, route.Entry.Alias)
	route.Finish("route removal")
}
// Update stops the old route and starts the new definition in its
// place, recording the alias as updated on success or collecting the
// start error on failure.
func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) {
	oldRoute.Finish("route update")
	if err := handler.provider.startRoute(parent, newRoute); err != nil {
		handler.errs.Add(err)
		return
	}
	handler.updated = append(handler.updated, newRoute.Entry.Alias)
}
// Log emits one summary line through the provider's logger describing
// every addition, removal and update performed by this handler, plus
// any collected errors. Nothing is logged when there is nothing to
// report.
func (handler *EventHandler) Log() {
	// "occurred" (typo fix: was "occured")
	results := E.NewBuilder("event occurred")
	for _, alias := range handler.added {
		results.Addf("added %s", alias)
	}
	for _, alias := range handler.removed {
		results.Addf("removed %s", alias)
	}
	for _, alias := range handler.updated {
		results.Addf("updated %s", alias)
	}
	results.Add(handler.errs.Build())
	if result := results.Build(); result != nil {
		handler.provider.l.Info(result)
	}
}

View file

@ -7,8 +7,8 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
) )
@ -42,38 +42,13 @@ func (p FileProvider) String() string {
return p.fileName return p.fileName
} }
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err != nil {
b.Add(err)
return
}
res.nRemoved = newRoutes.Size()
routes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Start())
})
res.nAdded = newRoutes.Size()
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
routes = R.NewRoutes() routes = R.NewRoutes()
b := E.NewBuilder("file %q validation failure", p.fileName) b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res) defer b.To(&res)
entries := types.NewProxyEntries() entries := entry.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path)) data, err := E.Check(os.ReadFile(p.path))
if err != nil { if err != nil {

View file

@ -1,14 +1,16 @@
package provider package provider
import ( import (
"context" "fmt"
"path" "path"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
) )
type ( type (
@ -19,18 +21,14 @@ type (
t ProviderType t ProviderType
routes R.Routes routes R.Routes
watcher W.Watcher watcher W.Watcher
watcherTask common.Task
watcherCancel context.CancelFunc
l *logrus.Entry l *logrus.Entry
} }
ProviderImpl interface { ProviderImpl interface {
fmt.Stringer
NewWatcher() W.Watcher NewWatcher() W.Watcher
// even returns error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError) LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult
String() string
} }
ProviderType string ProviderType string
ProviderStats struct { ProviderStats struct {
@ -38,17 +36,13 @@ type (
NumStreams int `json:"num_streams"` NumStreams int `json:"num_streams"`
Type ProviderType `json:"type"` Type ProviderType `json:"type"`
} }
EventResult struct {
nAdded int
nRemoved int
nReloaded int
err E.NestedError
}
) )
const ( const (
ProviderTypeDocker ProviderType = "docker" ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file" ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 500 * time.Millisecond
) )
func newProvider(name string, t ProviderType) *Provider { func newProvider(name string, t ProviderType) *Provider {
@ -106,32 +100,48 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil return []byte(p.String()), nil
} }
func (p *Provider) StartAllRoutes() (res E.NestedError) { func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
subtask := parent.Subtask("route %s", r.Entry.Alias)
err := r.Start(subtask)
if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err.String()) // just to ensure
return err
} else {
subtask.OnComplete("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
}
return nil
}
// Start implements task.TaskStarter.
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
errors := E.NewBuilder("errors starting routes") errors := E.NewBuilder("errors starting routes")
defer errors.To(&res) defer errors.To(&res)
// start watcher no matter load success or not // routes and event queue will stop on parent cancel
go p.watchEvents() providerTask := configSubtask
p.routes.RangeAllParallel(func(alias string, r *R.Route) { p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Start().Subject(r)) errors.Add(p.startRoute(providerTask, r))
}) })
return
}
func (p *Provider) StopAllRoutes() (res E.NestedError) { eventQueue := events.NewEventQueue(
if p.watcherCancel != nil { providerTask,
p.watcherCancel() providerEventFlushInterval,
p.watcherCancel = nil func(flushTask task.Task, events []events.Event) {
} handler := p.newEventHandler()
// routes' lifetime should follow the provider's lifetime
errors := E.NewBuilder("errors stopping routes") handler.Handle(providerTask, events)
defer errors.To(&res) handler.Log()
flushTask.Finish("events flushed")
p.routes.RangeAllParallel(func(alias string, r *R.Route) { },
errors.Add(r.Stop().Subject(r)) func(err E.NestedError) {
}) p.l.Error(err)
p.routes.Clear() },
)
eventQueue.Start(p.watcher.Events(providerTask.Context()))
return return
} }
@ -147,7 +157,6 @@ func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError var err E.NestedError
p.routes, err = p.LoadRoutesImpl() p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 { if p.routes.Size() > 0 {
p.l.Infof("loaded %d routes", p.routes.Size())
return err return err
} }
if err == nil { if err == nil {
@ -156,13 +165,14 @@ func (p *Provider) LoadRoutes() E.NestedError {
return E.FailWith("loading routes", err) return E.FailWith("loading routes", err)
} }
func (p *Provider) NumRoutes() int {
return p.routes.Size()
}
func (p *Provider) Statistics() ProviderStats { func (p *Provider) Statistics() ProviderStats {
numRPs := 0 numRPs := 0
numStreams := 0 numStreams := 0
p.routes.RangeAll(func(_ string, r *R.Route) { p.routes.RangeAll(func(_ string, r *R.Route) {
if !r.Started() {
return
}
switch r.Type { switch r.Type {
case R.RouteTypeReverseProxy: case R.RouteTypeReverseProxy:
numRPs++ numRPs++
@ -176,34 +186,3 @@ func (p *Provider) Statistics() ProviderStats {
Type: p.t, Type: p.t,
} }
} }
func (p *Provider) watchEvents() {
p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name)
defer p.watcherTask.Finished()
events, errs := p.watcher.Events(p.watcherTask.Context())
l := p.l.WithField("module", "watcher")
for {
select {
case <-p.watcherTask.Context().Done():
return
case event := <-events:
task := p.watcherTask.Subtask("%s event %s", event.Type, event)
l.Infof("%s event %q", event.Type, event)
res := p.OnEvent(event, p.routes)
task.Finished()
if res.nAdded+res.nRemoved+res.nReloaded > 0 {
l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded)
}
if res.err != nil {
l.Error(res.err)
}
case err := <-errs:
if err == nil || err.Is(context.Canceled) {
continue
}
l.Errorf("watcher error: %s", err)
}
}
}

View file

@ -4,8 +4,8 @@ import (
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types" url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy" "github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/types" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
@ -16,16 +16,16 @@ type (
_ U.NoCopy _ U.NoCopy
impl impl
Type RouteType Type RouteType
Entry *types.RawEntry Entry *entry.RawEntry
} }
Routes = F.Map[string, *Route] Routes = F.Map[string, *Route]
impl interface { impl interface {
Start() E.NestedError entry.Entry
Stop() E.NestedError task.TaskStarter
Started() bool task.TaskFinisher
String() string String() string
URL() url.URL TargetURL() url.URL
} }
) )
@ -44,8 +44,8 @@ func (rt *Route) Container() *docker.Container {
return rt.Entry.Container return rt.Entry.Container
} }
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
entry, err := P.ValidateEntry(en) en, err := entry.ValidateEntry(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -53,11 +53,11 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
var t RouteType var t RouteType
var rt impl var rt impl
switch e := entry.(type) { switch e := en.(type) {
case *P.StreamEntry: case *entry.StreamEntry:
t = RouteTypeStream t = RouteTypeStream
rt, err = NewStreamRoute(e) rt, err = NewStreamRoute(e)
case *P.ReverseProxyEntry: case *entry.ReverseProxyEntry:
t = RouteTypeReverseProxy t = RouteTypeReverseProxy
rt, err = NewHTTPRoute(e) rt, err = NewHTTPRoute(e)
default: default:
@ -69,19 +69,21 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
return &Route{ return &Route{
impl: rt, impl: rt,
Type: t, Type: t,
Entry: en, Entry: raw,
}, nil }, nil
} }
func FromEntries(entries types.RawEntries) (Routes, E.NestedError) { func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes") b := E.NewBuilder("errors in routes")
routes := NewRoutes() routes := NewRoutes()
entries.RangeAll(func(alias string, entry *types.RawEntry) { entries.RangeAllParallel(func(alias string, en *entry.RawEntry) {
entry.Alias = alias en.Alias = alias
r, err := NewRoute(entry) r, err := NewRoute(en)
if err != nil { if err != nil {
b.Add(err.Subject(alias)) b.Add(err.Subject(alias))
} else if entry.ShouldNotServe(r) {
return
} else { } else {
routes.Store(alias, r) routes.Store(alias, r)
} }

View file

@ -4,169 +4,141 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net" stdNet "net"
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types" net "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy" "github.com/yusing/go-proxy/internal/proxy/entry"
PT "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/task"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
type StreamRoute struct { type StreamRoute struct {
*P.StreamEntry *entry.StreamEntry
StreamImpl `json:"-"` net.Stream `json:"-"`
HealthMon health.HealthMonitor `json:"health"` HealthMon health.HealthMonitor `json:"health"`
url url.URL task task.Task
task common.Task
cancel context.CancelFunc
done chan struct{}
l logrus.FieldLogger l logrus.FieldLogger
mu sync.Mutex
} }
type StreamImpl interface { var (
Setup() error streamRoutes = F.NewMapOf[string, *StreamRoute]()
Accept() (any, error) streamRoutesMu sync.Mutex
Handle(conn any) error )
CloseListeners()
String() string
}
var streamRoutes = F.NewMapOf[string, *StreamRoute]()
func GetStreamProxies() F.Map[string, *StreamRoute] { func GetStreamProxies() F.Map[string, *StreamRoute] {
return streamRoutes return streamRoutes
} }
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
// TODO: support non-coherent scheme // TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() { if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
} }
url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort)) return &StreamRoute{
if err != nil {
// !! should not happen
panic(err)
}
base := &StreamRoute{
StreamEntry: entry, StreamEntry: entry,
url: url, task: task.DummyTask(),
} }, nil
if entry.Scheme.ListeningScheme.IsTCP() { }
base.StreamImpl = NewTCPRoute(base)
} else { func (r *StreamRoute) Finish(reason string) {
base.StreamImpl = NewUDPRoute(base) r.task.Finish(reason)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, nil
} }
func (r *StreamRoute) String() string { func (r *StreamRoute) String() string {
return fmt.Sprintf("stream %s", r.Alias) return fmt.Sprintf("stream %s", r.Alias)
} }
func (r *StreamRoute) URL() url.URL { // Start implements task.TaskStarter.
return r.url func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
} if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
func (r *StreamRoute) Start() E.NestedError {
r.mu.Lock()
defer r.mu.Unlock()
if r.Port.ProxyPort == PT.NoPort || r.task != nil {
return nil return nil
} }
r.task, r.cancel = common.NewTaskWithCancel(r.String())
streamRoutesMu.Lock()
defer streamRoutesMu.Unlock()
if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = true
}
if r.Scheme.ListeningScheme.IsTCP() {
r.Stream = NewTCPRoute(r)
} else {
r.Stream = NewUDPRoute(r)
}
r.l = logrus.WithField("route", r.Stream.String())
switch {
case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil {
return err
}
r.Stream = waker
r.HealthMon = waker
case entry.UseHealthCheck(r):
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
}
r.task = providerSubtask
r.task.OnComplete("stop stream", r.CloseListeners)
if err := r.Setup(); err != nil { if err := r.Setup(); err != nil {
return E.FailWith("setup", err) return E.FailWith("setup", err)
} }
r.done = make(chan struct{})
r.l.Infof("listening on port %d", r.Port.ListeningPort) r.l.Infof("listening on port %d", r.Port.ListeningPort)
go r.acceptConnections() go r.acceptConnections()
if !r.Healthcheck.Disabled {
r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck) if r.HealthMon != nil {
r.HealthMon.Start() r.HealthMon.Start(r.task.Subtask("health monitor"))
} }
streamRoutes.Store(string(r.Alias), r) streamRoutes.Store(string(r.Alias), r)
return nil return nil
} }
func (r *StreamRoute) Stop() E.NestedError {
r.mu.Lock()
defer r.mu.Unlock()
if r.task == nil {
return nil
}
streamRoutes.Delete(string(r.Alias))
if r.HealthMon != nil {
r.HealthMon.Stop()
r.HealthMon = nil
}
r.cancel()
r.CloseListeners()
<-r.done
return nil
}
func (r *StreamRoute) Started() bool {
return r.task != nil
}
func (r *StreamRoute) acceptConnections() { func (r *StreamRoute) acceptConnections() {
var connWg sync.WaitGroup
task := r.task.Subtask("%s accept connections", r.String())
defer func() {
connWg.Wait()
task.Finished()
r.task.Finished()
r.task, r.cancel = nil, nil
close(r.done)
r.done = nil
}()
for { for {
select { select {
case <-task.Context().Done(): case <-r.task.Context().Done():
return return
default: default:
conn, err := r.Accept() conn, err := r.Accept()
if err != nil { if err != nil {
select { select {
case <-task.Context().Done(): case <-r.task.Context().Done():
return return
default: default:
var nErr *net.OpError var nErr *stdNet.OpError
ok := errors.As(err, &nErr) ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) { if !(ok && nErr.Timeout()) {
r.l.Error(err) r.l.Error("accept connection error: ", err)
r.task.Finish(err.Error())
return
} }
continue continue
} }
} }
connWg.Add(1) connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
go func() { go func() {
err := r.Handle(conn) err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) { if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err) r.l.Error(err)
connTask.Finish(err.Error())
} else {
connTask.Finish("connection closed")
} }
connWg.Done() conn.Close()
}() }()
} }
} }

View file

@ -6,6 +6,7 @@ import (
"net" "net"
"time" "time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields" T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
@ -21,7 +22,7 @@ type (
} }
) )
func NewTCPRoute(base *StreamRoute) StreamImpl { func NewTCPRoute(base *StreamRoute) *TCPRoute {
return &TCPRoute{StreamRoute: base} return &TCPRoute{StreamRoute: base}
} }
@ -36,19 +37,16 @@ func (route *TCPRoute) Setup() error {
return nil return nil
} }
func (route *TCPRoute) Accept() (any, error) { func (route *TCPRoute) Accept() (types.StreamConn, error) {
route.listener.SetDeadline(time.Now().Add(time.Second)) route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept() return route.listener.Accept()
} }
func (route *TCPRoute) Handle(c any) error { func (route *TCPRoute) Handle(c types.StreamConn) error {
clientConn := c.(net.Conn) clientConn := c.(net.Conn)
defer clientConn.Close() defer clientConn.Close()
go func() { route.task.OnComplete("close conn", func() { clientConn.Close() })
<-route.task.Context().Done()
clientConn.Close()
}()
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
@ -70,5 +68,4 @@ func (route *TCPRoute) CloseListeners() {
return return
} }
route.listener.Close() route.listener.Close()
route.listener = nil
} }

View file

@ -1,11 +1,13 @@
package route package route
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"time" "time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields" T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
@ -33,7 +35,7 @@ var NewUDPConnMap = F.NewMap[UDPConnMap]
const udpBufferSize = 8192 const udpBufferSize = 8192
func NewUDPRoute(base *StreamRoute) StreamImpl { func NewUDPRoute(base *StreamRoute) *UDPRoute {
return &UDPRoute{ return &UDPRoute{
StreamRoute: base, StreamRoute: base,
connMap: NewUDPConnMap(), connMap: NewUDPConnMap(),
@ -64,7 +66,7 @@ func (route *UDPRoute) Setup() error {
return nil return nil
} }
func (route *UDPRoute) Accept() (any, error) { func (route *UDPRoute) Accept() (types.StreamConn, error) {
in := route.listeningConn in := route.listeningConn
buffer := make([]byte, udpBufferSize) buffer := make([]byte, udpBufferSize)
@ -104,7 +106,7 @@ func (route *UDPRoute) Accept() (any, error) {
return conn, err return conn, err
} }
func (route *UDPRoute) Handle(c any) error { func (route *UDPRoute) Handle(c types.StreamConn) error {
conn := c.(*UDPConn) conn := c.(*UDPConn)
err := conn.Start() err := conn.Start()
route.connMap.Delete(conn.key) route.connMap.Delete(conn.key)
@ -114,19 +116,25 @@ func (route *UDPRoute) Handle(c any) error {
func (route *UDPRoute) CloseListeners() { func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil { if route.listeningConn != nil {
route.listeningConn.Close() route.listeningConn.Close()
route.listeningConn = nil
} }
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) { route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil { if err := conn.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err) route.l.Errorf("error closing conn: %s", err)
}
if err := conn.dst.Close(); err != nil {
route.l.Error("error closing dst conn: %s", err)
} }
}) })
route.connMap.Clear() route.connMap.Clear()
} }
// Close implements types.StreamConn
func (conn *UDPConn) Close() error {
return errors.Join(conn.src.Close(), conn.dst.Close())
}
// RemoteAddr implements types.StreamConn
func (conn *UDPConn) RemoteAddr() net.Addr {
return conn.src.RemoteAddr()
}
type sourceRWCloser struct { type sourceRWCloser struct {
server *net.UDPConn server *net.UDPConn
*net.UDPConn *net.UDPConn

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"log" "log"
@ -9,8 +10,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/task"
"golang.org/x/net/context"
) )
type Server struct { type Server struct {
@ -21,7 +21,8 @@ type Server struct {
httpStarted bool httpStarted bool
httpsStarted bool httpsStarted bool
startTime time.Time startTime time.Time
task common.Task
task task.Task
} }
type Options struct { type Options struct {
@ -84,7 +85,7 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider, CertProvider: opt.CertProvider,
http: httpSer, http: httpSer,
https: httpsSer, https: httpsSer,
task: common.GlobalTask(opt.Name + " server"), task: task.GlobalTask(opt.Name + " server"),
} }
} }
@ -115,11 +116,7 @@ func (s *Server) Start() {
}() }()
} }
go func() { s.task.OnComplete("stop server", s.stop)
<-s.task.Context().Done()
s.stop()
s.task.Finished()
}()
} }
func (s *Server) stop() { func (s *Server) stop() {
@ -127,16 +124,13 @@ func (s *Server) stop() {
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if s.http != nil && s.httpStarted { if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(ctx)) s.handleErr("http", s.http.Shutdown(s.task.Context()))
s.httpStarted = false s.httpStarted = false
} }
if s.https != nil && s.httpsStarted { if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(ctx)) s.handleErr("https", s.https.Shutdown(s.task.Context()))
s.httpsStarted = false s.httpsStarted = false
} }
} }
@ -147,7 +141,7 @@ func (s *Server) Uptime() time.Duration {
func (s *Server) handleErr(scheme string, err error) { func (s *Server) handleErr(scheme string, err error) {
switch { switch {
case err == nil, errors.Is(err, http.ErrServerClosed): case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return return
default: default:
logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err) logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err)

View file

@ -0,0 +1,43 @@
package task
import "context"
// dummyTask is a no-op placeholder Task for objects that have not been
// started yet. Lifecycle methods (Finish, Wait, ...) are harmless
// no-ops; methods that would hand out real resources (Context, Subtask,
// OnComplete, Parent) panic to flag incorrect use before start.
type dummyTask struct{}

// Compile-time check that dummyTask satisfies Task (in particular that
// Finish has the Task-required `reason string` parameter).
var _ Task = dummyTask{}

// DummyTask returns a usable no-op Task.
//
// Bug fix: the previous `return` of a named zero `Task` result returned
// a nil interface, so any method call on the "dummy" nil-panicked.
func DummyTask() Task {
	return dummyTask{}
}

// Context implements Task. It panics: a dummy task has no context.
func (d dummyTask) Context() context.Context {
	panic("call of dummyTask.Context")
}

// FinishCause implements Task. A dummy task is never canceled, so the
// cause is always nil.
func (d dummyTask) FinishCause() error {
	return nil
}

// Finish implements Task as a no-op.
func (d dummyTask) Finish(reason string) {}

// Name implements Task.
func (d dummyTask) Name() string {
	return "Dummy Task"
}

// OnComplete implements Task. It panics: callbacks on a dummy task
// would never fire.
func (d dummyTask) OnComplete(about string, fn func()) {
	panic("call of dummyTask.OnComplete")
}

// Parent implements Task. It panics: a dummy task has no parent.
func (d dummyTask) Parent() Task {
	panic("call of dummyTask.Parent")
}

// Subtask implements Task. It panics: a dummy task cannot own children.
func (d dummyTask) Subtask(usageFmt string, args ...any) Task {
	panic("call of dummyTask.Subtask")
}

// Wait implements Task as a no-op.
func (d dummyTask) Wait() {}

// WaitSubTasks implements Task as a no-op.
func (d dummyTask) WaitSubTasks() {}

310
internal/task/task.go Normal file
View file

@ -0,0 +1,310 @@
package task
import (
"context"
"errors"
"fmt"
"runtime"
"strings"
"sync"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
// globalTask is the root of the task tree; every task created through
// GlobalTask or Subtask ultimately descends from it.
var globalTask = createGlobalTask()

// createGlobalTask builds the root task. Its context derives from the
// background context and is only canceled via CancelGlobalContext.
func createGlobalTask() *task {
	root := &task{
		name:     "root",
		subtasks: xsync.NewMapOf[*task, struct{}](),
	}
	root.ctx, root.cancel = context.WithCancelCause(context.Background())
	return root
}
type (
	// Task controls objects' lifetime.
	//
	// Task must be initialized, use DummyTask if the task is not yet started.
	//
	// Objects that uses a task should implement the TaskStarter and the TaskFinisher interface.
	//
	// When passing a Task object to another function,
	// it must be a sub-task of the current task,
	// in name of "`currentTaskName`Subtask"
	//
	// Use Task.Finish to stop all subtasks of the task.
	Task interface {
		TaskFinisher
		// Name returns the name of the task.
		Name() string
		// Context returns the context associated with the task. This context is
		// canceled when Finish is called.
		Context() context.Context
		// FinishCause returns the reason / error that caused the task to be finished.
		FinishCause() error
		// Parent returns the parent task of the current task.
		Parent() Task
		// Subtask returns a new subtask with the given name, derived from the parent's context.
		//
		// If the parent's context is already canceled, the returned subtask will be canceled immediately.
		//
		// This should not be called after Finish, Wait, or WaitSubTasks is called.
		Subtask(usageFmt string, args ...any) Task
		// OnComplete calls fn when the task and all subtasks are finished.
		//
		// It cannot be called after Finish or Wait is called.
		OnComplete(about string, fn func())
		// Wait waits for all subtasks, itself and all OnComplete to finish.
		//
		// It must be called only after Finish is called.
		Wait()
		// WaitSubTasks waits for all subtasks of the task to finish.
		//
		// No more subtasks can be added after this call.
		//
		// It can be called before Finish is called.
		WaitSubTasks()
	}
	// TaskStarter is implemented by objects whose lifetime is driven by
	// a Task passed in at start.
	TaskStarter interface {
		// Start starts the object that implements TaskStarter,
		// and returns an error if it fails to start.
		//
		// The task passed must be a subtask of the caller task.
		//
		// callerSubtask.Finish must be called when start fails or the object is finished.
		Start(callerSubtask Task) E.NestedError
	}
	// TaskFinisher is implemented by objects that can be stopped with a
	// human-readable reason.
	TaskFinisher interface {
		// Finish marks the task as finished by cancelling its context.
		//
		// Then call Wait to wait for all subtasks and OnComplete of the task to finish.
		//
		// Note that it will also cancel all subtasks.
		Finish(reason string)
	}
	// task is the concrete Task implementation; tasks form a tree
	// rooted at globalTask.
	task struct {
		ctx    context.Context         // canceled (with cause) when the task finishes
		cancel context.CancelCauseFunc // cancels ctx, recording why

		parent   *task                         // nil only for the root task
		subtasks *xsync.MapOf[*task, struct{}] // live children, used by tree/serialize

		// line is the creation call site, populated only when tracing
		// (common.IsTrace) is enabled.
		name, line string

		// subTasksWg counts live children; onCompleteWg counts pending
		// OnComplete callbacks.
		subTasksWg, onCompleteWg sync.WaitGroup
	}
)
var (
	// ErrProgramExiting is the cancel cause used when the global task
	// context is canceled on shutdown (see CancelGlobalContext).
	ErrProgramExiting = errors.New("program exiting")
	// ErrTaskCancelled is the cancel cause used when an individual task
	// is finished via Finish.
	ErrTaskCancelled = errors.New("task cancelled")
)
// GlobalTask returns a new Task with the given name (a printf-style
// format plus args), derived from the global context.
func GlobalTask(format string, args ...any) Task {
	return globalTask.Subtask(format, args...)
}

// DebugTaskMap returns a map[string]any representation of the global task tree.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func DebugTaskMap() map[string]any {
	return globalTask.serialize()
}

// CancelGlobalContext cancels the global task context (with cause
// ErrProgramExiting), which will cause all tasks created to be
// canceled. This should be called before exiting the program
// to ensure that all tasks are properly cleaned up.
func CancelGlobalContext() {
	globalTask.cancel(ErrProgramExiting)
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it logs a warning listing the tasks that
// were still running (as a tree of subtasks) and returns.
func GlobalContextWait(timeout time.Duration) {
	done := make(chan struct{})
	go func() {
		globalTask.Wait()
		close(done)
	}()
	// Both branches terminate, so the previous `for` wrapper around
	// this select was dead code after the first iteration.
	select {
	case <-done:
	case <-time.After(timeout):
		logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
	}
}
// Name returns the name of the task.
func (t *task) Name() string {
	return t.name
}

// Context returns the task's context; it is canceled (with a cause)
// when the task is finished.
func (t *task) Context() context.Context {
	return t.ctx
}

// FinishCause returns the cancel cause of the task's context, i.e. the
// reason / error that caused the task to be finished (nil if the task
// is still running).
func (t *task) FinishCause() error {
	return context.Cause(t.ctx)
}

// Parent returns the parent task of the current task.
func (t *task) Parent() Task {
	return t.parent
}
// OnComplete registers fn to run once all subtasks have finished and
// the task's context has been canceled. fn runs on its own goroutine;
// a panic inside fn is recovered and logged instead of crashing the
// process. After fn returns, the task's context is canceled (nil
// cause) to release its resources.
func (t *task) OnComplete(about string, fn func()) {
	t.onCompleteWg.Add(1)
	var file string
	var line int
	if common.IsTrace {
		// record the registration call site for trace logging
		_, file, line, _ = runtime.Caller(1)
	}
	go func() {
		defer func() {
			if err := recover(); err != nil {
				logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err)
			}
		}()
		defer t.onCompleteWg.Done()
		// order matters: all children first, then the task itself
		t.subTasksWg.Wait()
		<-t.ctx.Done()
		fn()
		logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about)
		t.cancel(nil) // ensure resources are released
	}()
}
// Finish cancels the task's context with ErrTaskCancelled (annotated
// with the task name and the given reason), then blocks until all
// subtasks and OnComplete callbacks have completed.
func (t *task) Finish(reason string) {
	t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason))
	t.Wait()
}
// Subtask creates and registers a child task whose context derives
// from this task's context. The name is format expanded with args
// (format is used verbatim when no args are given).
func (t *task) Subtask(format string, args ...any) Task {
	name := format
	if len(args) > 0 {
		name = fmt.Sprintf(format, args...)
	}
	// note: newSubTask must stay exactly one call below the user
	// (it uses runtime.Caller(3) to find the originating call site)
	ctx, cancel := context.WithCancelCause(t.ctx)
	return t.newSubTask(ctx, cancel, name)
}
// newSubTask wires a child task into the parent's bookkeeping: it registers
// the subtask in parent.subtasks and parent.subTasksWg, and spawns a
// goroutine that deregisters it once the subtask fully finishes.
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
	parent := t
	subtask := &task{
		ctx:      ctx,
		cancel:   cancel,
		name:     name,
		parent:   parent,
		subtasks: xsync.NewMapOf[*task, struct{}](),
	}
	parent.subTasksWg.Add(1)
	parent.subtasks.Store(subtask, struct{}{})
	if common.IsTrace {
		// Caller(3) hardcodes the call depth to reach the external call site
		// (past Subtask and this helper) — fragile if the call chain changes;
		// TODO confirm the depth if new wrappers are added.
		_, file, line, ok := runtime.Caller(3)
		if ok {
			subtask.line = fmt.Sprintf("%s:%d", file, line)
		}
		logrus.Tracef("line %s\ntask %q started", subtask.line, name)
		go func() {
			subtask.Wait()
			logrus.Tracef("task %q finished", subtask.Name())
		}()
	}
	// Deregister from the parent once this subtask is fully done.
	go func() {
		subtask.Wait()
		parent.subtasks.Delete(subtask)
		parent.subTasksWg.Done()
	}()
	return subtask
}
// Wait blocks until the task is fully finished: all subtasks done, the
// task's context cancelled, and all OnComplete hooks returned.
func (t *task) Wait() {
	t.subTasksWg.Wait()
	// GlobalContextWait calls Wait on the root task without cancelling it
	// first, so waiting on the root context here would block forever.
	if t != globalTask {
		<-t.ctx.Done()
	}
	t.onCompleteWg.Wait()
}
// WaitSubTasks blocks until all subtasks are done, without waiting for the
// task's own context or its OnComplete hooks.
func (t *task) WaitSubTasks() {
	t.subTasksWg.Wait()
}
// tree renders the task and all of its subtasks as an indented list, one
// task per line, traversed depth-first. When a creation site (line) was
// recorded it is printed on its own line above the task name. The optional
// prefix is the indentation for this level; each nesting level adds two
// spaces. Output is for debugging only and not guaranteed to be stable
// between runs.
func (t *task) tree(prefix ...string) string {
	var b strings.Builder
	indent := ""
	if len(prefix) > 0 {
		indent = prefix[0]
		b.WriteString(indent + "- ")
	}
	if t.line != "" {
		b.WriteString("line " + t.line + "\n")
	}
	if indent != "" {
		b.WriteString(indent + "- ")
	}
	b.WriteString(t.Name() + "\n")
	t.subtasks.Range(func(child *task, _ struct{}) bool {
		b.WriteString(child.tree(indent + "  "))
		return true
	})
	return b.String()
}
// serialize returns a map[string]any representation of the task tree.
//
// The map contains the following keys:
//   - name: the name of the task
//   - line: the creation site of the task, if recorded
//   - subtasks: a slice of maps, one per subtask (same keys, recursively)
//
// The returned map is suitable for encoding to JSON and is intended for
// debugging purposes only; it is not guaranteed to be stable between runs.
func (t *task) serialize() map[string]any {
	m := make(map[string]any)
	m["name"] = t.name
	if t.line != "" {
		m["line"] = t.line
	}
	// Build the slice in a local variable: the original re-asserted
	// m["subtasks"].([]map[string]any) on every append.
	if n := t.subtasks.Size(); n > 0 {
		subtasks := make([]map[string]any, 0, n)
		t.subtasks.Range(func(subtask *task, _ struct{}) bool {
			subtasks = append(subtasks, subtask.serialize())
			return true
		})
		m["subtasks"] = subtasks
	}
	return m
}

147
internal/task/task_test.go Normal file
View file

@ -0,0 +1,147 @@
package task_test
import (
"context"
"sync/atomic"
"testing"
"time"
. "github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
// TestTaskCreation checks that tasks keep the names they were created with.
func TestTaskCreation(t *testing.T) {
	defer CancelGlobalContext()
	root := GlobalTask("root-task")
	child := root.Subtask("subtask")

	ExpectEqual(t, "root-task", root.Name())
	ExpectEqual(t, "subtask", child.Name())
}
// TestTaskCancellation verifies that finishing a parent task cancels its
// subtask's context with ErrTaskCancelled as the cause.
func TestTaskCancellation(t *testing.T) {
	defer CancelGlobalContext()
	subTaskDone := make(chan struct{})

	rootTask := GlobalTask("root-task")
	subTask := rootTask.Subtask("subtask")
	go func() {
		subTask.Wait()
		close(subTaskDone)
	}()
	go rootTask.Finish("done")
	select {
	case <-subTaskDone:
		err := subTask.Context().Err()
		ExpectError(t, context.Canceled, err)
		cause := context.Cause(subTask.Context())
		ExpectError(t, ErrTaskCancelled, cause)
	case <-time.After(1 * time.Second): // guard against deadlock
		t.Fatal("subTask context was not canceled as expected")
	}
}
// TestGlobalContextCancellation verifies that CancelGlobalContext cancels a
// task's context with ErrProgramExiting as the cause.
func TestGlobalContextCancellation(t *testing.T) {
	taskDone := make(chan struct{})

	rootTask := GlobalTask("root-task")
	go func() {
		rootTask.Wait()
		close(taskDone)
	}()

	CancelGlobalContext()

	select {
	case <-taskDone:
		err := rootTask.Context().Err()
		ExpectError(t, context.Canceled, err)
		cause := context.Cause(rootTask.Context())
		ExpectError(t, ErrProgramExiting, cause)
	case <-time.After(1 * time.Second): // guard against deadlock
		t.Fatal("subTask context was not canceled as expected")
	}
}
// TestOnComplete checks that an OnComplete hook runs before Finish returns.
func TestOnComplete(t *testing.T) {
	defer CancelGlobalContext()
	tsk := GlobalTask("test")

	var value atomic.Int32
	tsk.OnComplete("set value", func() {
		value.Store(1234)
	})
	tsk.Finish("done")
	ExpectEqual(t, value.Load(), 1234)
}
// TestGlobalContextWait verifies that GlobalContextWait returns only after
// all subtasks have finished and their OnComplete hooks have run.
func TestGlobalContextWait(t *testing.T) {
	defer CancelGlobalContext()
	rootTask := GlobalTask("root-task")
	finished1, finished2 := false, false

	subTask1 := rootTask.Subtask("subtask1")
	subTask2 := rootTask.Subtask("subtask2")
	subTask1.OnComplete("set finished", func() {
		finished1 = true
	})
	subTask2.OnComplete("set finished", func() {
		finished2 = true
	})
	go func() {
		time.Sleep(500 * time.Millisecond)
		subTask1.Finish("done")
	}()
	go func() {
		time.Sleep(500 * time.Millisecond)
		subTask2.Finish("done")
	}()
	go func() {
		subTask1.Wait()
		subTask2.Wait()
		rootTask.Finish("done")
	}()
	// NOTE(review): finished1/finished2 are plain bools written from
	// OnComplete goroutines; reading them here relies on GlobalContextWait
	// waiting on the hooks for the happens-before edge — confirm with -race.
	GlobalContextWait(1 * time.Second)
	ExpectTrue(t, finished1)
	ExpectTrue(t, finished2)
	ExpectError(t, context.Canceled, rootTask.Context().Err())
	ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context()))
	ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context()))
}
// TestTimeoutOnGlobalContextWait verifies that GlobalContextWait stays
// blocked while a subtask is still running (its 500ms timeout cannot have
// elapsed after only 200ms).
func TestTimeoutOnGlobalContextWait(t *testing.T) {
	defer CancelGlobalContext()
	rootTask := GlobalTask("root-task")
	subTask := rootTask.Subtask("subtask")

	done := make(chan struct{})
	go func() {
		GlobalContextWait(500 * time.Millisecond)
		close(done)
	}()
	select {
	case <-done:
		t.Fatal("GlobalContextWait should have timed out")
	case <-time.After(200 * time.Millisecond):
	}

	// Ensure clean exit
	subTask.Finish("exit")
}
// TestGlobalContextCancel is an empty placeholder.
// TODO(review): implement the intended assertions or remove this test.
func TestGlobalContextCancel(t *testing.T) {
}

View file

@ -1,18 +0,0 @@
package types
// Config is the top-level application configuration.
type Config struct {
	Providers       ProxyProviders `json:"providers" yaml:",flow"` // route sources (include files, docker hosts)
	AutoCert        AutoCertConfig `json:"autocert" yaml:",flow"`
	ExplicitOnly    bool           `json:"explicit_only" yaml:"explicit_only"`
	MatchDomains    []string       `json:"match_domains" yaml:"match_domains"`
	TimeoutShutdown int            `json:"timeout_shutdown" yaml:"timeout_shutdown"` // presumably seconds (default 3) — TODO confirm unit
	RedirectToHTTPS bool           `json:"redirect_to_https" yaml:"redirect_to_https"`
}
// DefaultConfig returns a Config populated with the built-in defaults.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.Providers = ProxyProviders{}
	cfg.TimeoutShutdown = 3
	cfg.RedirectToHTTPS = false
	return cfg
}

View file

@ -1,6 +0,0 @@
package types
// ProxyProviders lists the sources of proxy route definitions.
type ProxyProviders struct {
	Files  []string          `json:"include" yaml:"include"` // docker, file
	Docker map[string]string `json:"docker" yaml:"docker"`   // presumably name -> docker host address — TODO confirm
}

View file

@ -23,7 +23,7 @@ func IgnoreError[Result any](r Result, _ error) Result {
func ExpectNoError(t *testing.T, err error) { func ExpectNoError(t *testing.T, err error) {
t.Helper() t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() { if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", err.Error()) t.Errorf("expected err=nil, got %s", err)
t.FailNow() t.FailNow()
} }
} }
@ -31,7 +31,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) { func ExpectError(t *testing.T, expected error, err error) {
t.Helper() t.Helper()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected.Error(), err.Error()) t.Errorf("expected err %s, got %s", expected, err)
t.FailNow() t.FailNow()
} }
} }
@ -39,7 +39,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) { func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper() t.Helper()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error()) t.Errorf("%v: expected err %s, got %s", input, expected, err)
t.FailNow() t.FailNow()
} }
} }

View file

@ -15,8 +15,9 @@ import (
type ( type (
DockerWatcher struct { DockerWatcher struct {
host string host string
client D.Client client D.Client
clientOwned bool
logrus.FieldLogger logrus.FieldLogger
} }
DockerListOptions = docker_events.ListOptions DockerListOptions = docker_events.ListOptions
@ -44,10 +45,11 @@ func DockerrFilterContainer(nameOrID string) filters.KeyValuePair {
func NewDockerWatcher(host string) DockerWatcher { func NewDockerWatcher(host string) DockerWatcher {
return DockerWatcher{ return DockerWatcher{
host: host,
clientOwned: true,
FieldLogger: (logrus. FieldLogger: (logrus.
WithField("module", "docker_watcher"). WithField("module", "docker_watcher").
WithField("host", host)), WithField("host", host)),
host: host,
} }
} }
@ -72,7 +74,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
defer close(errCh) defer close(errCh)
defer func() { defer func() {
if w.client.Connected() { if w.clientOwned && w.client.Connected() {
w.client.Close() w.client.Close()
} }
}() }()

View file

@ -0,0 +1,91 @@
package events
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
)
type (
	// EventQueue buffers incoming events and flushes them periodically.
	EventQueue struct {
		task    task.Task    // subtask that bounds the queue's lifetime
		queue   []Event      // events buffered since the last flush
		ticker  *time.Ticker // fires once per flush interval
		onFlush OnFlushFunc
		onError OnErrorFunc
	}
	// OnFlushFunc receives the pending events; it must eventually call
	// flushTask.Finish (it may return before the flush completes).
	OnFlushFunc = func(flushTask task.Task, events []Event)
	// OnErrorFunc handles errors from errCh and panics in onFlush.
	OnErrorFunc = func(err E.NestedError)
)

// eventQueueCapacity is the initial capacity of each fresh queue slice.
const eventQueueCapacity = 10
// NewEventQueue creates an EventQueue owned by a subtask of parent that
// batches events and flushes them every flushInterval.
//
// onFlush is called (with a fresh flush subtask) each time the interval
// elapses and the queue is non-empty; it must eventually call
// flushTask.Finish, but may return earlier (e.g. run the flush in another
// goroutine). onError is called for errors received from errCh and for
// panics inside onFlush (reported as an E.ErrPanicRecv error).
//
// If the task is cancelled before the next interval, events still in the
// queue are discarded.
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
	q := &EventQueue{
		task:    parent.Subtask("event queue"),
		ticker:  time.NewTicker(flushInterval),
		onFlush: onFlush,
		onError: onError,
	}
	q.queue = make([]Event, 0, eventQueueCapacity)
	return q
}
// Start launches the queue's event loop in a new goroutine.
//
// The loop exits when the queue's task is cancelled or eventCh is closed.
// Buffered events are handed to onFlush once per tick; errors from errCh
// are forwarded to onError, and a panic inside onFlush is recovered and
// reported through onError as well.
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) {
	go func() {
		defer e.ticker.Stop()
		for {
			select {
			case <-e.task.Context().Done():
				e.task.Finish(e.task.FinishCause().Error())
				return
			case <-e.ticker.C:
				if len(e.queue) == 0 {
					continue
				}
				flushTask := e.task.Subtask("flush events")
				queue := e.queue
				e.queue = make([]Event, 0, eventQueueCapacity)
				go func() {
					defer func() {
						if err := recover(); err != nil {
							e.onError(E.PanicRecv("panic in onFlush %s", err))
						}
					}()
					e.onFlush(flushTask, queue)
				}()
				// Block further intake until this flush is finished.
				flushTask.Wait()
			case event, ok := <-eventCh:
				if !ok {
					// Channel closed: return without appending the zero-value
					// Event that a closed receive yields (previously appended).
					return
				}
				e.queue = append(e.queue, event)
			case err, ok := <-errCh:
				if !ok {
					// A closed errCh is always ready and would busy-spin this
					// loop; a nil channel blocks forever, disabling the case.
					errCh = nil
					continue
				}
				if err != nil {
					e.onError(err)
				}
			}
		}
	}()
}
// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
	// Delegates to the queue's task, which waits for flush subtasks too.
	e.task.Wait()
}

View file

@ -74,7 +74,7 @@ var actionNameMap = func() (m map[Action]string) {
}() }()
func (e Event) String() string { func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action) return fmt.Sprintf("%s %s", e.Action, e.ActorName)
} }
func (a Action) String() string { func (a Action) String() string {

View file

@ -5,7 +5,6 @@ import (
"errors" "errors"
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -15,10 +14,10 @@ type HTTPHealthMonitor struct {
pinger *http.Client pinger *http.Client
} }
func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor { func NewHTTPHealthMonitor(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) *HTTPHealthMonitor {
mon := new(HTTPHealthMonitor) mon := new(HTTPHealthMonitor)
mon.monitor = newMonitor(task, url, config, mon.checkHealth) mon.monitor = newMonitor(url, config, mon.CheckHealth)
mon.pinger = &http.Client{Timeout: config.Timeout} mon.pinger = &http.Client{Timeout: config.Timeout, Transport: transport}
if config.UseGet { if config.UseGet {
mon.method = http.MethodGet mon.method = http.MethodGet
} else { } else {
@ -27,19 +26,26 @@ func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckCo
return mon return mon
} }
func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) { func NewHTTPHealthChecker(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) HealthChecker {
return NewHTTPHealthMonitor(url, config, transport)
}
func (mon *HTTPHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
defer cancel()
req, reqErr := http.NewRequestWithContext( req, reqErr := http.NewRequestWithContext(
mon.task.Context(), ctx,
mon.method, mon.method,
mon.url.JoinPath(mon.config.Path).String(), mon.url.Load().JoinPath(mon.config.Path).String(),
nil, nil,
) )
if reqErr != nil { if reqErr != nil {
err = reqErr err = reqErr
return return
} }
req.Header.Set("Connection", "close")
req.Header.Set("Connection", "close")
resp, respErr := mon.pinger.Do(req) resp, respErr := mon.pinger.Do(req)
if respErr == nil { if respErr == nil {
resp.Body.Close() resp.Body.Close()

View file

@ -2,78 +2,93 @@ package health
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"sync" "fmt"
"time" "time"
"github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
type ( type (
HealthMonitor interface { HealthMonitor interface {
Start() task.TaskStarter
Stop() task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status Status() Status
Uptime() time.Duration Uptime() time.Duration
Name() string Name() string
String() string }
MarshalJSON() ([]byte, error) HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
} }
HealthCheckFunc func() (healthy bool, detail string, err error) HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct { monitor struct {
service string service string
config *HealthCheckConfig config *HealthCheckConfig
url types.URL url U.AtomicValue[types.URL]
status U.AtomicValue[Status] status U.AtomicValue[Status]
checkHealth HealthCheckFunc checkHealth HealthCheckFunc
startTime time.Time startTime time.Time
task common.Task task task.Task
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
} }
) )
var monMap = F.NewMapOf[string, HealthMonitor]() var monMap = F.NewMapOf[string, HealthMonitor]()
func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
service := task.Name()
task, cancel := task.SubtaskWithCancel("Health monitor for %s", service)
mon := &monitor{ mon := &monitor{
service: service,
config: config, config: config,
url: url,
checkHealth: healthCheckFunc, checkHealth: healthCheckFunc,
startTime: time.Now(), startTime: time.Now(),
task: task, task: task.DummyTask(),
cancel: cancel,
done: make(chan struct{}),
} }
mon.url.Store(url)
mon.status.Store(StatusHealthy) mon.status.Store(StatusHealthy)
return mon return mon
} }
func Inspect(name string) (HealthMonitor, bool) { func Inspect(service string) (HealthMonitor, bool) {
return monMap.Load(name) return monMap.Load(service)
} }
func (mon *monitor) Start() { func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) {
defer monMap.Store(mon.task.Name(), mon) if mon.task != nil {
return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause))
} else {
return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause))
}
}
// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask
if err := mon.checkUpdateHealth(); err != nil {
mon.task.Finish(fmt.Sprintf("healthchecker %s failure: %s", mon.service, err))
return err
}
go func() { go func() {
defer close(mon.done) defer func() {
defer mon.task.Finished() monMap.Delete(mon.task.Name())
if mon.status.Load() != StatusError {
mon.status.Store(StatusUnknown)
}
mon.task.Finish(mon.task.FinishCause().Error())
}()
ok := mon.checkUpdateHealth() monMap.Store(mon.service, mon)
if !ok {
return
}
ticker := time.NewTicker(mon.config.Interval) ticker := time.NewTicker(mon.config.Interval)
defer ticker.Stop() defer ticker.Stop()
@ -83,48 +98,61 @@ func (mon *monitor) Start() {
case <-mon.task.Context().Done(): case <-mon.task.Context().Done():
return return
case <-ticker.C: case <-ticker.C:
ok = mon.checkUpdateHealth() err := mon.checkUpdateHealth()
if !ok { if err != nil {
logger.Errorf("healthchecker %s failure: %s", mon.service, err)
return return
} }
} }
} }
}() }()
return nil
} }
func (mon *monitor) Stop() { // Finish implements task.TaskFinisher.
monMap.Delete(mon.task.Name()) func (mon *monitor) Finish(reason string) {
mon.task.Finish(reason)
mon.mu.Lock()
defer mon.mu.Unlock()
if mon.cancel == nil {
return
}
mon.cancel()
<-mon.done
mon.cancel = nil
mon.status.Store(StatusUnknown)
} }
// UpdateURL implements HealthChecker.
func (mon *monitor) UpdateURL(url types.URL) {
mon.url.Store(url)
}
// URL implements HealthChecker.
func (mon *monitor) URL() types.URL {
return mon.url.Load()
}
// Config implements HealthChecker.
func (mon *monitor) Config() *HealthCheckConfig {
return mon.config
}
// Status implements HealthMonitor.
func (mon *monitor) Status() Status { func (mon *monitor) Status() Status {
return mon.status.Load() return mon.status.Load()
} }
// Uptime implements HealthMonitor.
func (mon *monitor) Uptime() time.Duration { func (mon *monitor) Uptime() time.Duration {
return time.Since(mon.startTime) return time.Since(mon.startTime)
} }
// Name implements HealthMonitor.
func (mon *monitor) Name() string { func (mon *monitor) Name() string {
if mon.task == nil {
return ""
}
return mon.task.Name() return mon.task.Name()
} }
// String implements fmt.Stringer of HealthMonitor.
func (mon *monitor) String() string { func (mon *monitor) String() string {
return mon.Name() return mon.Name()
} }
// MarshalJSON implements json.Marshaler of HealthMonitor.
func (mon *monitor) MarshalJSON() ([]byte, error) { func (mon *monitor) MarshalJSON() ([]byte, error) {
return (&JSONRepresentation{ return (&JSONRepresentation{
Name: mon.service, Name: mon.service,
@ -132,19 +160,19 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
Status: mon.status.Load(), Status: mon.status.Load(),
Started: mon.startTime, Started: mon.startTime,
Uptime: mon.Uptime(), Uptime: mon.Uptime(),
URL: mon.url, URL: mon.url.Load(),
}).MarshalJSON() }).MarshalJSON()
} }
func (mon *monitor) checkUpdateHealth() (hasError bool) { func (mon *monitor) checkUpdateHealth() E.NestedError {
healthy, detail, err := mon.checkHealth() healthy, detail, err := mon.checkHealth()
if err != nil { if err != nil {
defer mon.task.Finish(err.Error())
mon.status.Store(StatusError) mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
logger.Errorf("%s failed to check health: %s", mon.service, err) return E.Failure("check health").With(err)
} }
mon.Stop() return nil
return false
} }
var status Status var status Status
if healthy { if healthy {
@ -160,5 +188,5 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
} }
} }
return true return nil
} }

View file

@ -3,7 +3,6 @@ package health
import ( import (
"net" "net"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -14,9 +13,9 @@ type (
} }
) )
func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor { func NewRawHealthMonitor(url types.URL, config *HealthCheckConfig) *RawHealthMonitor {
mon := new(RawHealthMonitor) mon := new(RawHealthMonitor)
mon.monitor = newMonitor(task, url, config, mon.checkAvail) mon.monitor = newMonitor(url, config, mon.CheckHealth)
mon.dialer = &net.Dialer{ mon.dialer = &net.Dialer{
Timeout: config.Timeout, Timeout: config.Timeout,
FallbackDelay: -1, FallbackDelay: -1,
@ -24,14 +23,22 @@ func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckCon
return mon return mon
} }
func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) { func NewRawHealthChecker(url types.URL, config *HealthCheckConfig) HealthChecker {
conn, dialErr := mon.dialer.DialContext(mon.task.Context(), mon.url.Scheme, mon.url.Host) return NewRawHealthMonitor(url, config)
}
func (mon *RawHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
defer cancel()
url := mon.url.Load()
conn, dialErr := mon.dialer.DialContext(ctx, url.Scheme, url.Host)
if dialErr != nil { if dialErr != nil {
detail = dialErr.Error() detail = dialErr.Error()
/* trunk-ignore(golangci-lint/nilerr) */ /* trunk-ignore(golangci-lint/nilerr) */
return return
} }
conn.Close() conn.Close()
avail = true healthy = true
return return
} }