mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
Fixed a few issues:
- Incorrect name being shown on dashboard "Proxies page" - Apps being shown when homepage.show is false - Load balanced routes are shown on homepage instead of the load balancer - Route with idlewatcher will now be removed on container destroy - Idlewatcher panic - Performance improvement - Idlewatcher infinitely loading - Reload stucked / not working properly - Streams stuck on shutdown / reload - etc... Added: - support idlewatcher for loadbalanced routes - partial implementation for stream type idlewatcher Issues: - graceful shutdown
This commit is contained in:
parent
c0c61709ca
commit
53557e38b6
69 changed files with 2368 additions and 1654 deletions
6
Makefile
6
Makefile
|
@ -30,6 +30,12 @@ get:
|
|||
debug:
|
||||
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
|
||||
debug-trace:
|
||||
make build && sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
|
||||
|
||||
profile:
|
||||
GODEBUG=gctrace=1 make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
|
||||
|
||||
mtrace:
|
||||
bin/go-proxy debug-ls-mtrace > mtrace.json
|
||||
|
||||
|
|
48
cmd/main.go
48
cmd/main.go
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/server"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
)
|
||||
|
||||
|
@ -32,8 +33,14 @@ func main() {
|
|||
}
|
||||
|
||||
l := logrus.WithField("module", "main")
|
||||
timeFmt := "01-02 15:04:05"
|
||||
fullTS := true
|
||||
|
||||
if common.IsDebug {
|
||||
if common.IsTrace {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
timeFmt = "04:05"
|
||||
fullTS = false
|
||||
} else if common.IsDebug {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
|
@ -42,9 +49,9 @@ func main() {
|
|||
} else {
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
DisableSorting: true,
|
||||
FullTimestamp: true,
|
||||
FullTimestamp: fullTS,
|
||||
ForceColors: true,
|
||||
TimestampFormat: "01-02 15:04:05",
|
||||
TimestampFormat: timeFmt,
|
||||
})
|
||||
logrus.Infof("go-proxy version %s", pkg.GetVersion())
|
||||
}
|
||||
|
@ -76,21 +83,22 @@ func main() {
|
|||
|
||||
middleware.LoadComposeFiles()
|
||||
|
||||
if err := config.Load(); err != nil {
|
||||
var cfg *config.Config
|
||||
var err E.NestedError
|
||||
if cfg, err = config.Load(); err != nil {
|
||||
logrus.Warn(err)
|
||||
}
|
||||
cfg := config.GetInstance()
|
||||
|
||||
switch args.Command {
|
||||
case common.CommandListConfigs:
|
||||
printJSON(cfg.Value())
|
||||
printJSON(config.Value())
|
||||
return
|
||||
case common.CommandListRoutes:
|
||||
routes, err := query.ListRoutes()
|
||||
if err != nil {
|
||||
log.Printf("failed to connect to api server: %s", err)
|
||||
log.Printf("falling back to config file")
|
||||
printJSON(cfg.RoutesByAlias())
|
||||
printJSON(config.RoutesByAlias())
|
||||
} else {
|
||||
printJSON(routes)
|
||||
}
|
||||
|
@ -103,10 +111,10 @@ func main() {
|
|||
printJSON(icons)
|
||||
return
|
||||
case common.CommandDebugListEntries:
|
||||
printJSON(cfg.DumpEntries())
|
||||
printJSON(config.DumpEntries())
|
||||
return
|
||||
case common.CommandDebugListProviders:
|
||||
printJSON(cfg.DumpProviders())
|
||||
printJSON(config.DumpProviders())
|
||||
return
|
||||
case common.CommandDebugListMTrace:
|
||||
trace, err := query.ListMiddlewareTraces()
|
||||
|
@ -114,17 +122,25 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
printJSON(trace)
|
||||
return
|
||||
case common.CommandDebugListTasks:
|
||||
tasks, err := query.DebugListTasks()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
printJSON(tasks)
|
||||
return
|
||||
}
|
||||
|
||||
cfg.StartProxyProviders()
|
||||
cfg.WatchChanges()
|
||||
config.WatchChanges()
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT)
|
||||
signal.Notify(sig, syscall.SIGTERM)
|
||||
signal.Notify(sig, syscall.SIGHUP)
|
||||
|
||||
autocert := cfg.GetAutoCertProvider()
|
||||
autocert := config.GetAutoCertProvider()
|
||||
if autocert != nil {
|
||||
if err := autocert.Setup(); err != nil {
|
||||
l.Fatal(err)
|
||||
|
@ -139,14 +155,14 @@ func main() {
|
|||
HTTPAddr: common.ProxyHTTPAddr,
|
||||
HTTPSAddr: common.ProxyHTTPSAddr,
|
||||
Handler: http.HandlerFunc(R.ProxyHandler),
|
||||
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
|
||||
RedirectToHTTPS: config.Value().RedirectToHTTPS,
|
||||
})
|
||||
apiServer := server.InitAPIServer(server.Options{
|
||||
Name: "api",
|
||||
CertProvider: autocert,
|
||||
HTTPAddr: common.APIHTTPAddr,
|
||||
Handler: api.NewHandler(cfg),
|
||||
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
|
||||
Handler: api.NewHandler(),
|
||||
RedirectToHTTPS: config.Value().RedirectToHTTPS,
|
||||
})
|
||||
|
||||
proxyServer.Start()
|
||||
|
@ -157,8 +173,8 @@ func main() {
|
|||
|
||||
// grafully shutdown
|
||||
logrus.Info("shutting down")
|
||||
common.CancelGlobalContext()
|
||||
common.GlobalContextWait(time.Second * time.Duration(cfg.Value().TimeoutShutdown))
|
||||
task.CancelGlobalContext()
|
||||
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
|
||||
}
|
||||
|
||||
func prepareDirectory(dir string) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package api
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
|
@ -21,34 +22,35 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc
|
|||
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
|
||||
}
|
||||
|
||||
func NewHandler(cfg *config.Config) http.Handler {
|
||||
func NewHandler() http.Handler {
|
||||
mux := NewServeMux()
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
|
||||
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
|
||||
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload))
|
||||
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List))
|
||||
mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth)
|
||||
mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth)
|
||||
mux.HandleFunc("POST", "/v1/reload", v1.Reload)
|
||||
mux.HandleFunc("GET", "/v1/list", v1.List)
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", v1.List)
|
||||
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
|
||||
mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent)
|
||||
mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent)
|
||||
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
|
||||
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS))
|
||||
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
|
||||
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
|
||||
return mux
|
||||
}
|
||||
|
||||
// allow only requests to API server with host matching common.APIHTTPAddr.
|
||||
// allow only requests to API server with localhost.
|
||||
func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
||||
if common.IsDebug {
|
||||
return f
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host != common.APIHTTPAddr {
|
||||
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr)
|
||||
http.Error(w, "invalid request", http.StatusForbidden)
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
|
||||
Logger.Warnf("blocked API request from %s", host)
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
f(w, r)
|
||||
|
|
|
@ -4,11 +4,10 @@ import (
|
|||
"net/http"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
func CheckHealth(w http.ResponseWriter, r *http.Request) {
|
||||
target := r.FormValue("target")
|
||||
if target == "" {
|
||||
HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest)
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
"github.com/yusing/go-proxy/internal/route/provider"
|
||||
)
|
||||
|
||||
func GetFileContent(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -9,19 +9,21 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
ListRoutes = "routes"
|
||||
ListConfigFiles = "config_files"
|
||||
ListMiddlewares = "middlewares"
|
||||
ListMiddlewareTrace = "middleware_trace"
|
||||
ListMatchDomains = "match_domains"
|
||||
ListHomepageConfig = "homepage_config"
|
||||
ListRoutes = "routes"
|
||||
ListConfigFiles = "config_files"
|
||||
ListMiddlewares = "middlewares"
|
||||
ListMiddlewareTraces = "middleware_trace"
|
||||
ListMatchDomains = "match_domains"
|
||||
ListHomepageConfig = "homepage_config"
|
||||
ListTasks = "tasks"
|
||||
)
|
||||
|
||||
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
func List(w http.ResponseWriter, r *http.Request) {
|
||||
what := r.PathValue("what")
|
||||
if what == "" {
|
||||
what = ListRoutes
|
||||
|
@ -29,27 +31,24 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
switch what {
|
||||
case ListRoutes:
|
||||
listRoutes(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type"))))
|
||||
case ListConfigFiles:
|
||||
listConfigFiles(w, r)
|
||||
case ListMiddlewares:
|
||||
listMiddlewares(w, r)
|
||||
case ListMiddlewareTrace:
|
||||
listMiddlewareTrace(w, r)
|
||||
U.RespondJSON(w, r, middleware.All())
|
||||
case ListMiddlewareTraces:
|
||||
U.RespondJSON(w, r, middleware.GetAllTrace())
|
||||
case ListMatchDomains:
|
||||
listMatchDomains(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.Value().MatchDomains)
|
||||
case ListHomepageConfig:
|
||||
listHomepageConfig(cfg, w, r)
|
||||
U.RespondJSON(w, r, config.HomepageConfig())
|
||||
case ListTasks:
|
||||
U.RespondJSON(w, r, task.DebugTaskMap())
|
||||
default:
|
||||
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type")))
|
||||
U.RespondJSON(w, r, routes)
|
||||
}
|
||||
|
||||
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
||||
files, err := utils.ListFiles(common.ConfigBasePath, 1)
|
||||
if err != nil {
|
||||
|
@ -61,19 +60,3 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
U.RespondJSON(w, r, files)
|
||||
}
|
||||
|
||||
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, middleware.GetAllTrace())
|
||||
}
|
||||
|
||||
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, middleware.All())
|
||||
}
|
||||
|
||||
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, cfg.Value().MatchDomains)
|
||||
}
|
||||
|
||||
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, cfg.HomepageConfig())
|
||||
}
|
||||
|
|
|
@ -34,36 +34,34 @@ func ReloadServer() E.NestedError {
|
|||
return nil
|
||||
}
|
||||
|
||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
|
||||
func List[T any](what string) (_ T, outErr E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
outErr = E.From(err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode)
|
||||
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
var routes map[string]map[string]any
|
||||
err = json.NewDecoder(resp.Body).Decode(&routes)
|
||||
var res T
|
||||
err = json.NewDecoder(resp.Body).Decode(&res)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
outErr = E.From(err)
|
||||
return
|
||||
}
|
||||
return routes, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||
return List[map[string]map[string]any](v1.ListRoutes)
|
||||
}
|
||||
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
|
||||
}
|
||||
var traces middleware.Traces
|
||||
err = json.NewDecoder(resp.Body).Decode(&traces)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
return traces, nil
|
||||
return List[middleware.Traces](v1.ListMiddlewareTraces)
|
||||
}
|
||||
|
||||
func DebugListTasks() (map[string]any, E.NestedError) {
|
||||
return List[map[string]any](v1.ListTasks)
|
||||
}
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/config"
|
||||
)
|
||||
|
||||
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
if err := cfg.Reload(); err != nil {
|
||||
func Reload(w http.ResponseWriter, r *http.Request) {
|
||||
if err := config.Reload(); err != nil {
|
||||
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
|
@ -14,19 +14,19 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, getStats(cfg))
|
||||
func Stats(w http.ResponseWriter, r *http.Request) {
|
||||
U.RespondJSON(w, r, getStats())
|
||||
}
|
||||
|
||||
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
||||
func StatsWS(w http.ResponseWriter, r *http.Request) {
|
||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||
originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses))
|
||||
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
|
||||
|
||||
if len(originPats) == 0 {
|
||||
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins")
|
||||
originPats = []string{"*"}
|
||||
} else {
|
||||
for i, domain := range cfg.Value().MatchDomains {
|
||||
for i, domain := range config.Value().MatchDomains {
|
||||
originPats[i] = "*." + domain
|
||||
}
|
||||
originPats = append(originPats, localAddresses...)
|
||||
|
@ -51,7 +51,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
|||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
stats := getStats(cfg)
|
||||
stats := getStats()
|
||||
if err := wsjson.Write(ctx, conn, stats); err != nil {
|
||||
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
|
||||
return
|
||||
|
@ -59,9 +59,9 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func getStats(cfg *config.Config) map[string]any {
|
||||
func getStats() map[string]any {
|
||||
return map[string]any{
|
||||
"proxies": cfg.Statistics(),
|
||||
"proxies": config.Statistics(),
|
||||
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/go-acme/lego/v4/lego"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
)
|
||||
|
||||
type Config types.AutoCertConfig
|
||||
|
|
|
@ -13,9 +13,9 @@ import (
|
|||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -140,23 +140,22 @@ func (p *Provider) ScheduleRenewal() {
|
|||
if p.GetName() == ProviderLocal {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
task := common.NewTask("cert renew scheduler")
|
||||
defer task.Finished()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-ticker.C: // check every 5 seconds
|
||||
if err := p.renewIfNeeded(); err.HasError() {
|
||||
logger.Warn(err)
|
||||
go func() {
|
||||
task := task.GlobalTask("cert renew scheduler")
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer task.Finish("cert renew scheduler stopped")
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-ticker.C: // check every 5 seconds
|
||||
if err := p.renewIfNeeded(); err.HasError() {
|
||||
logger.Warn(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Provider) initClient() E.NestedError {
|
||||
|
|
|
@ -17,7 +17,7 @@ func (p *Provider) Setup() (err E.NestedError) {
|
|||
}
|
||||
}
|
||||
|
||||
go p.ScheduleRenewal()
|
||||
p.ScheduleRenewal()
|
||||
|
||||
for _, expiry := range p.GetExpiries() {
|
||||
logger.Infof("certificate expire on %s", expiry)
|
||||
|
|
|
@ -22,6 +22,7 @@ const (
|
|||
CommandDebugListEntries = "debug-ls-entries"
|
||||
CommandDebugListProviders = "debug-ls-providers"
|
||||
CommandDebugListMTrace = "debug-ls-mtrace"
|
||||
CommandDebugListTasks = "debug-ls-tasks"
|
||||
)
|
||||
|
||||
var ValidCommands = []string{
|
||||
|
@ -35,6 +36,7 @@ var ValidCommands = []string{
|
|||
CommandDebugListEntries,
|
||||
CommandDebugListProviders,
|
||||
CommandDebugListMTrace,
|
||||
CommandDebugListTasks,
|
||||
}
|
||||
|
||||
func GetArgs() Args {
|
||||
|
|
|
@ -43,7 +43,6 @@ const (
|
|||
HealthCheckIntervalDefault = 5 * time.Second
|
||||
HealthCheckTimeoutDefault = 5 * time.Second
|
||||
|
||||
IdleTimeoutDefault = "0"
|
||||
WakeTimeoutDefault = "30s"
|
||||
StopTimeoutDefault = "10s"
|
||||
StopMethodDefault = "stop"
|
||||
|
|
|
@ -15,6 +15,7 @@ var (
|
|||
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true)
|
||||
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
|
||||
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
|
||||
IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug
|
||||
|
||||
ProxyHTTPAddr,
|
||||
ProxyHTTPHost,
|
||||
|
|
|
@ -1,224 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
globalCtx, globalCtxCancel = context.WithCancel(context.Background())
|
||||
taskWg sync.WaitGroup
|
||||
tasksMap = xsync.NewMapOf[*task, struct{}]()
|
||||
)
|
||||
|
||||
type (
|
||||
Task interface {
|
||||
Name() string
|
||||
Context() context.Context
|
||||
Subtask(usageFmt string, args ...interface{}) Task
|
||||
SubtaskWithCancel(usageFmt string, args ...interface{}) (Task, context.CancelFunc)
|
||||
Finished()
|
||||
}
|
||||
task struct {
|
||||
ctx context.Context
|
||||
subtasks []*task
|
||||
name string
|
||||
finished bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
func (t *task) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Context returns the context associated with the task. This context is
|
||||
// canceled when the task is finished.
|
||||
func (t *task) Context() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
// Finished marks the task as finished and notifies the global wait group.
|
||||
// Finished is thread-safe and idempotent.
|
||||
func (t *task) Finished() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.finished {
|
||||
return
|
||||
}
|
||||
t.finished = true
|
||||
if _, ok := tasksMap.Load(t); ok {
|
||||
taskWg.Done()
|
||||
tasksMap.Delete(t)
|
||||
}
|
||||
logrus.Debugf("task %q finished", t.Name())
|
||||
}
|
||||
|
||||
// Subtask returns a new subtask with the given name, derived from the receiver's context.
|
||||
//
|
||||
// The returned subtask is associated with the receiver's context and will be
|
||||
// automatically registered and deregistered from the global task wait group.
|
||||
//
|
||||
// If the receiver's context is already canceled, the returned subtask will be
|
||||
// canceled immediately.
|
||||
//
|
||||
// The returned subtask is safe for concurrent use.
|
||||
func (t *task) Subtask(format string, args ...interface{}) Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
sub := newSubTask(t.ctx, format)
|
||||
t.subtasks = append(t.subtasks, sub)
|
||||
return sub
|
||||
}
|
||||
|
||||
// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context,
|
||||
// and a cancel function. The returned subtask is associated with the receiver's context and will be
|
||||
// automatically registered and deregistered from the global task wait group.
|
||||
//
|
||||
// If the receiver's context is already canceled, the returned subtask will be canceled immediately.
|
||||
//
|
||||
// The returned cancel function is safe for concurrent use, and can be used to cancel the returned
|
||||
// subtask at any time.
|
||||
func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
ctx, cancel := context.WithCancel(t.ctx)
|
||||
sub := newSubTask(ctx, format)
|
||||
t.subtasks = append(t.subtasks, sub)
|
||||
return sub, cancel
|
||||
}
|
||||
|
||||
func (t *task) tree(prefix ...string) string {
|
||||
var sb strings.Builder
|
||||
var pre string
|
||||
if len(prefix) > 0 {
|
||||
pre = prefix[0]
|
||||
}
|
||||
sb.WriteString(pre)
|
||||
sb.WriteString(t.Name() + "\n")
|
||||
for _, sub := range t.subtasks {
|
||||
if sub.finished {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(sub.tree(pre + " "))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func newSubTask(ctx context.Context, name string) *task {
|
||||
t := &task{
|
||||
ctx: ctx,
|
||||
name: name,
|
||||
}
|
||||
tasksMap.Store(t, struct{}{})
|
||||
taskWg.Add(1)
|
||||
logrus.Debugf("task %q started", name)
|
||||
return t
|
||||
}
|
||||
|
||||
// NewTask returns a new Task with the given name, derived from the global
|
||||
// context.
|
||||
//
|
||||
// The returned Task is associated with the global context and will be
|
||||
// automatically registered and deregistered from the global context's wait
|
||||
// group.
|
||||
//
|
||||
// If the global context is already canceled, the returned Task will be
|
||||
// canceled immediately.
|
||||
//
|
||||
// The returned Task is not safe for concurrent use.
|
||||
func NewTask(format string, args ...interface{}) Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
return newSubTask(globalCtx, format)
|
||||
}
|
||||
|
||||
// NewTaskWithCancel returns a new Task with the given name, derived from the
|
||||
// global context, and a cancel function. The returned Task is associated with
|
||||
// the global context and will be automatically registered and deregistered
|
||||
// from the global task wait group.
|
||||
//
|
||||
// If the global context is already canceled, the returned Task will be
|
||||
// canceled immediately.
|
||||
//
|
||||
// The returned Task is safe for concurrent use.
|
||||
//
|
||||
// The returned cancel function is safe for concurrent use, and can be used
|
||||
// to cancel the returned Task at any time.
|
||||
func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
|
||||
subCtx, cancel := context.WithCancel(globalCtx)
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
return newSubTask(subCtx, format), cancel
|
||||
}
|
||||
|
||||
// GlobalTask returns a new Task with the given name, associated with the
|
||||
// global context.
|
||||
//
|
||||
// Unlike NewTask, GlobalTask does not automatically register or deregister
|
||||
// the Task with the global task wait group. The returned Task is not
|
||||
// started, but the name is formatted immediately.
|
||||
//
|
||||
// This is best used for main task that do not need to wait and
|
||||
// will create a bunch of subtasks.
|
||||
//
|
||||
// The returned Task is safe for concurrent use.
|
||||
func GlobalTask(format string, args ...interface{}) Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
return &task{
|
||||
ctx: globalCtx,
|
||||
name: format,
|
||||
}
|
||||
}
|
||||
|
||||
// CancelGlobalContext cancels the global context, which will cause all tasks
|
||||
// created by GlobalTask or NewTask to be canceled. This should be called
|
||||
// before exiting the program to ensure that all tasks are properly cleaned
|
||||
// up.
|
||||
func CancelGlobalContext() {
|
||||
globalCtxCancel()
|
||||
}
|
||||
|
||||
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func GlobalContextWait(timeout time.Duration) {
|
||||
done := make(chan struct{})
|
||||
after := time.After(timeout)
|
||||
go func() {
|
||||
taskWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-after:
|
||||
logrus.Warnln("Timeout waiting for these tasks to finish:")
|
||||
tasksMap.Range(func(t *task, _ struct{}) bool {
|
||||
logrus.Warnln(t.tree())
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,51 +2,66 @@ package config
|
|||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
PR "github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
value *types.Config
|
||||
proxyProviders F.Map[string, *PR.Provider]
|
||||
providers F.Map[string, *proxy.Provider]
|
||||
autocertProvider *autocert.Provider
|
||||
|
||||
l logrus.FieldLogger
|
||||
|
||||
watcher W.Watcher
|
||||
|
||||
reloadReq chan struct{}
|
||||
task task.Task
|
||||
}
|
||||
|
||||
var instance *Config
|
||||
var (
|
||||
instance *Config
|
||||
cfgWatcher watcher.Watcher
|
||||
logger = logrus.WithField("module", "config")
|
||||
reloadMu sync.Mutex
|
||||
)
|
||||
|
||||
const configEventFlushInterval = 500 * time.Millisecond
|
||||
|
||||
const (
|
||||
cfgRenameWarn = `Config file renamed, not reloading.
|
||||
Make sure you rename it back before next time you start.`
|
||||
cfgDeleteWarn = `Config file deleted, not reloading.
|
||||
You may run "ls-config" to show or dump the current config.`
|
||||
)
|
||||
|
||||
func GetInstance() *Config {
|
||||
return instance
|
||||
}
|
||||
|
||||
func Load() E.NestedError {
|
||||
func newConfig() *Config {
|
||||
return &Config{
|
||||
value: types.DefaultConfig(),
|
||||
providers: F.NewMapOf[string, *proxy.Provider](),
|
||||
task: task.GlobalTask("config"),
|
||||
}
|
||||
}
|
||||
|
||||
func Load() (*Config, E.NestedError) {
|
||||
if instance != nil {
|
||||
return nil
|
||||
return instance, nil
|
||||
}
|
||||
instance = &Config{
|
||||
value: types.DefaultConfig(),
|
||||
proxyProviders: F.NewMapOf[string, *PR.Provider](),
|
||||
l: logrus.WithField("module", "config"),
|
||||
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
|
||||
reloadReq: make(chan struct{}, 1),
|
||||
}
|
||||
return instance.load()
|
||||
instance = newConfig()
|
||||
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
|
||||
return instance, instance.load()
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
|
@ -54,87 +69,90 @@ func Validate(data []byte) E.NestedError {
|
|||
}
|
||||
|
||||
func MatchDomains() []string {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return instance.value.MatchDomains
|
||||
}
|
||||
|
||||
func (cfg *Config) Value() types.Config {
|
||||
if cfg == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
}
|
||||
return *cfg.value
|
||||
func WatchChanges() {
|
||||
task := task.GlobalTask("Config watcher")
|
||||
eventQueue := events.NewEventQueue(
|
||||
task,
|
||||
configEventFlushInterval,
|
||||
OnConfigChange,
|
||||
func(err E.NestedError) {
|
||||
logger.Error(err)
|
||||
},
|
||||
)
|
||||
eventQueue.Start(cfgWatcher.Events(task.Context()))
|
||||
}
|
||||
|
||||
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
|
||||
if instance == nil {
|
||||
logrus.Panic("config has not been loaded, please check if there is any errors")
|
||||
func OnConfigChange(flushTask task.Task, ev []events.Event) {
|
||||
defer flushTask.Finish("config reload complete")
|
||||
|
||||
// no matter how many events during the interval
|
||||
// just reload once and check the last event
|
||||
switch ev[len(ev)-1].Action {
|
||||
case events.ActionFileRenamed:
|
||||
logger.Warn(cfgRenameWarn)
|
||||
return
|
||||
case events.ActionFileDeleted:
|
||||
logger.Warn(cfgDeleteWarn)
|
||||
return
|
||||
}
|
||||
|
||||
if err := Reload(); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return cfg.autocertProvider
|
||||
}
|
||||
|
||||
func (cfg *Config) Reload() (err E.NestedError) {
|
||||
cfg.stopProviders()
|
||||
err = cfg.load()
|
||||
cfg.StartProxyProviders()
|
||||
return
|
||||
func Reload() E.NestedError {
|
||||
// avoid race between config change and API reload request
|
||||
reloadMu.Lock()
|
||||
defer reloadMu.Unlock()
|
||||
|
||||
newCfg := newConfig()
|
||||
err := newCfg.load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cancel all current subtasks -> wait
|
||||
// -> replace config -> start new subtasks
|
||||
instance.task.Finish("config changed")
|
||||
instance.task.Wait()
|
||||
*instance = *newCfg
|
||||
instance.StartProxyProviders()
|
||||
return nil
|
||||
}
|
||||
|
||||
func Value() types.Config {
|
||||
return *instance.value
|
||||
}
|
||||
|
||||
func GetAutoCertProvider() *autocert.Provider {
|
||||
return instance.autocertProvider
|
||||
}
|
||||
|
||||
func (cfg *Config) Task() task.Task {
|
||||
return cfg.task
|
||||
}
|
||||
|
||||
func (cfg *Config) StartProxyProviders() {
|
||||
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
|
||||
}
|
||||
|
||||
func (cfg *Config) WatchChanges() {
|
||||
task := common.NewTask("Config watcher")
|
||||
go func() {
|
||||
defer task.Finished()
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-cfg.reloadReq:
|
||||
if err := cfg.Reload(); err != nil {
|
||||
cfg.l.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
eventCh, errCh := cfg.watcher.Events(task.Context())
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case event := <-eventCh:
|
||||
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
|
||||
cfg.l.Error("config file deleted or renamed, ignoring...")
|
||||
continue
|
||||
} else {
|
||||
cfg.reloadReq <- struct{}{}
|
||||
}
|
||||
case err := <-errCh:
|
||||
cfg.l.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) {
|
||||
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
|
||||
p.RangeRoutes(func(a string, r *R.Route) {
|
||||
do(a, r, p)
|
||||
})
|
||||
b := E.NewBuilder("errors starting providers")
|
||||
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
|
||||
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName())))
|
||||
})
|
||||
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) load() (res E.NestedError) {
|
||||
b := E.NewBuilder("errors loading config")
|
||||
defer b.To(&res)
|
||||
|
||||
cfg.l.Debug("loading config")
|
||||
defer cfg.l.Debug("loaded config")
|
||||
logger.Debug("loading config")
|
||||
defer logger.Debug("loaded config")
|
||||
|
||||
data, err := E.Check(os.ReadFile(common.ConfigPath))
|
||||
if err != nil {
|
||||
|
@ -160,7 +178,7 @@ func (cfg *Config) load() (res E.NestedError) {
|
|||
b.Add(cfg.loadProviders(&model.Providers))
|
||||
|
||||
cfg.value = model
|
||||
R.SetFindMuxDomains(model.MatchDomains)
|
||||
route.SetFindMuxDomains(model.MatchDomains)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -169,8 +187,8 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
|
|||
return
|
||||
}
|
||||
|
||||
cfg.l.Debug("initializing autocert")
|
||||
defer cfg.l.Debug("initialized autocert")
|
||||
logger.Debug("initializing autocert")
|
||||
defer logger.Debug("initialized autocert")
|
||||
|
||||
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
|
||||
if err != nil {
|
||||
|
@ -179,48 +197,34 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
|
|||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) {
|
||||
cfg.l.Debug("loading providers")
|
||||
defer cfg.l.Debug("loaded providers")
|
||||
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) {
|
||||
subtask := cfg.task.Subtask("load providers")
|
||||
defer subtask.Finish("done")
|
||||
|
||||
b := E.NewBuilder("errors loading providers")
|
||||
defer b.To(&res)
|
||||
errs := E.NewBuilder("errors loading providers")
|
||||
results := E.NewBuilder("loaded providers")
|
||||
defer errs.To(&outErr)
|
||||
|
||||
for _, filename := range providers.Files {
|
||||
p, err := PR.NewFileProvider(filename)
|
||||
p, err := proxy.NewFileProvider(filename)
|
||||
if err != nil {
|
||||
b.Add(err.Subject(filename))
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
cfg.proxyProviders.Store(p.GetName(), p)
|
||||
b.Add(p.LoadRoutes().Subject(filename))
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
errs.Add(p.LoadRoutes().Subject(filename))
|
||||
results.Addf("%d routes from %s", p.NumRoutes(), filename)
|
||||
}
|
||||
for name, dockerHost := range providers.Docker {
|
||||
p, err := PR.NewDockerProvider(name, dockerHost)
|
||||
p, err := proxy.NewDockerProvider(name, dockerHost)
|
||||
if err != nil {
|
||||
b.Add(err.Subjectf("%s (%s)", name, dockerHost))
|
||||
errs.Add(err.Subjectf("%s (%s)", name, dockerHost))
|
||||
continue
|
||||
}
|
||||
cfg.proxyProviders.Store(p.GetName(), p)
|
||||
b.Add(p.LoadRoutes().Subject(p.GetName()))
|
||||
cfg.providers.Store(p.GetName(), p)
|
||||
errs.Add(p.LoadRoutes().Subject(p.GetName()))
|
||||
results.Addf("%d routes from %s", p.NumRoutes(), name)
|
||||
}
|
||||
logger.Info(results.Build())
|
||||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
|
||||
errors := E.NewBuilder("errors in %s these providers", action)
|
||||
|
||||
cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
|
||||
if err := do(p); err != nil {
|
||||
errors.Add(err.Subject(p))
|
||||
}
|
||||
})
|
||||
|
||||
if err := errors.Build(); err != nil {
|
||||
cfg.l.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) stopProviders() {
|
||||
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
|
||||
}
|
||||
|
|
|
@ -6,33 +6,35 @@ import (
|
|||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
PR "github.com/yusing/go-proxy/internal/proxy/provider"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
|
||||
entries := make(map[string]*types.RawEntry)
|
||||
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
|
||||
entries[alias] = r.Entry
|
||||
func DumpEntries() map[string]*entry.RawEntry {
|
||||
entries := make(map[string]*entry.RawEntry)
|
||||
instance.providers.RangeAll(func(_ string, p *proxy.Provider) {
|
||||
p.RangeRoutes(func(alias string, r *route.Route) {
|
||||
entries[alias] = r.Entry
|
||||
})
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
|
||||
entries := make(map[string]*PR.Provider)
|
||||
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
|
||||
func DumpProviders() map[string]*proxy.Provider {
|
||||
entries := make(map[string]*proxy.Provider)
|
||||
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
|
||||
entries[name] = p
|
||||
})
|
||||
return entries
|
||||
}
|
||||
|
||||
func (cfg *Config) HomepageConfig() homepage.Config {
|
||||
func HomepageConfig() homepage.Config {
|
||||
var proto, port string
|
||||
domains := cfg.value.MatchDomains
|
||||
cert, _ := cfg.autocertProvider.GetCert(nil)
|
||||
domains := instance.value.MatchDomains
|
||||
cert, _ := instance.autocertProvider.GetCert(nil)
|
||||
if cert != nil {
|
||||
proto = "https"
|
||||
port = common.ProxyHTTPSPort
|
||||
|
@ -42,9 +44,9 @@ func (cfg *Config) HomepageConfig() homepage.Config {
|
|||
}
|
||||
|
||||
hpCfg := homepage.NewHomePageConfig()
|
||||
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
|
||||
entry := r.Raw
|
||||
item := entry.Homepage
|
||||
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
|
||||
en := r.Raw
|
||||
item := en.Homepage
|
||||
if item == nil {
|
||||
item = new(homepage.Item)
|
||||
item.Show = true
|
||||
|
@ -63,12 +65,12 @@ func (cfg *Config) HomepageConfig() homepage.Config {
|
|||
)
|
||||
}
|
||||
|
||||
if r.IsDocker() {
|
||||
if entry.IsDocker(r) {
|
||||
if item.Category == "" {
|
||||
item.Category = "Docker"
|
||||
}
|
||||
item.SourceType = string(PR.ProviderTypeDocker)
|
||||
} else if r.UseLoadBalance() {
|
||||
item.SourceType = string(proxy.ProviderTypeDocker)
|
||||
} else if entry.UseLoadBalance(r) {
|
||||
if item.Category == "" {
|
||||
item.Category = "Load-balanced"
|
||||
}
|
||||
|
@ -77,7 +79,7 @@ func (cfg *Config) HomepageConfig() homepage.Config {
|
|||
if item.Category == "" {
|
||||
item.Category = "Others"
|
||||
}
|
||||
item.SourceType = string(PR.ProviderTypeFile)
|
||||
item.SourceType = string(proxy.ProviderTypeFile)
|
||||
}
|
||||
|
||||
if item.URL == "" {
|
||||
|
@ -85,26 +87,26 @@ func (cfg *Config) HomepageConfig() homepage.Config {
|
|||
item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port)
|
||||
}
|
||||
}
|
||||
item.AltURL = r.URL().String()
|
||||
item.AltURL = r.TargetURL().String()
|
||||
|
||||
hpCfg.Add(item)
|
||||
})
|
||||
return hpCfg
|
||||
}
|
||||
|
||||
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
|
||||
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
|
||||
routes := make(map[string]any)
|
||||
if len(typeFilter) == 0 || typeFilter[0] == "" {
|
||||
typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream}
|
||||
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
|
||||
}
|
||||
for _, t := range typeFilter {
|
||||
switch t {
|
||||
case R.RouteTypeReverseProxy:
|
||||
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
|
||||
case route.RouteTypeReverseProxy:
|
||||
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
|
||||
routes[alias] = r
|
||||
})
|
||||
case R.RouteTypeStream:
|
||||
R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) {
|
||||
case route.RouteTypeStream:
|
||||
route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) {
|
||||
routes[alias] = r
|
||||
})
|
||||
}
|
||||
|
@ -112,12 +114,12 @@ func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
|
|||
return routes
|
||||
}
|
||||
|
||||
func (cfg *Config) Statistics() map[string]any {
|
||||
func Statistics() map[string]any {
|
||||
nTotalStreams := 0
|
||||
nTotalRPs := 0
|
||||
providerStats := make(map[string]PR.ProviderStats)
|
||||
providerStats := make(map[string]proxy.ProviderStats)
|
||||
|
||||
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
|
||||
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
|
||||
providerStats[name] = p.Statistics()
|
||||
})
|
||||
|
||||
|
@ -133,9 +135,9 @@ func (cfg *Config) Statistics() map[string]any {
|
|||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) FindRoute(alias string) *R.Route {
|
||||
return F.MapFind(cfg.proxyProviders,
|
||||
func(p *PR.Provider) (*R.Route, bool) {
|
||||
func FindRoute(alias string) *route.Route {
|
||||
return F.MapFind(instance.providers,
|
||||
func(p *proxy.Provider) (*route.Route, bool) {
|
||||
if route, ok := p.GetRoute(alias); ok {
|
||||
return route, true
|
||||
}
|
||||
|
|
24
internal/config/types/config.go
Normal file
24
internal/config/types/config.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package types
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
Providers ProxyProviders `json:"providers" yaml:",flow"`
|
||||
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
|
||||
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
|
||||
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
|
||||
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
|
||||
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
|
||||
}
|
||||
ProxyProviders struct {
|
||||
Files []string `json:"include" yaml:"include"` // docker, file
|
||||
Docker map[string]string `json:"docker" yaml:"docker"`
|
||||
}
|
||||
)
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Providers: ProxyProviders{},
|
||||
TimeoutShutdown: 3,
|
||||
RedirectToHTTPS: false,
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
@ -36,22 +37,13 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
task := common.NewTask("close all docker client")
|
||||
defer task.Finished()
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
clientMap.RangeAllParallel(func(_ string, c Client) {
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
}
|
||||
})
|
||||
clientMap.Clear()
|
||||
return
|
||||
task.GlobalTask("close docker clients").OnComplete("", func() {
|
||||
clientMap.RangeAllParallel(func(_ string, c Client) {
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SharedClient) Connected() bool {
|
||||
|
@ -141,19 +133,10 @@ func ConnectClient(host string) (Client, E.NestedError) {
|
|||
<-c.refCount.Zero()
|
||||
clientMap.Delete(c.key)
|
||||
|
||||
if c.Client != nil {
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
c.Client = nil
|
||||
c.l.Debugf("client closed")
|
||||
}
|
||||
}()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func CloseAllClients() {
|
||||
clientMap.RangeAllParallel(func(_ string, c Client) {
|
||||
c.Client.Close()
|
||||
})
|
||||
clientMap.Clear()
|
||||
logger.Debug("closed all clients")
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package docker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
|
@ -16,10 +17,13 @@ type ClientInfo struct {
|
|||
}
|
||||
|
||||
var listOptions = container.ListOptions{
|
||||
// created|restarting|running|removing|paused|exited|dead
|
||||
// Filters: filters.NewArgs(
|
||||
// filters.Arg("health", "healthy"),
|
||||
// filters.Arg("health", "none"),
|
||||
// filters.Arg("health", "starting"),
|
||||
// filters.Arg("status", "created"),
|
||||
// filters.Arg("status", "restarting"),
|
||||
// filters.Arg("status", "running"),
|
||||
// filters.Arg("status", "paused"),
|
||||
// filters.Arg("status", "exited"),
|
||||
// ),
|
||||
All: true,
|
||||
}
|
||||
|
@ -31,7 +35,7 @@ func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedE
|
|||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout"))
|
||||
defer cancel()
|
||||
|
||||
var containers []types.Container
|
||||
|
|
|
@ -32,11 +32,11 @@ type (
|
|||
IsExcluded bool `json:"is_excluded" yaml:"-"`
|
||||
IsExplicit bool `json:"is_explicit" yaml:"-"`
|
||||
IsDatabase bool `json:"is_database" yaml:"-"`
|
||||
IdleTimeout string `json:"idle_timeout" yaml:"-"`
|
||||
WakeTimeout string `json:"wake_timeout" yaml:"-"`
|
||||
StopMethod string `json:"stop_method" yaml:"-"`
|
||||
StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only
|
||||
StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only
|
||||
IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"`
|
||||
WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"`
|
||||
StopMethod string `json:"stop_method,omitempty" yaml:"-"`
|
||||
StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only
|
||||
StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only
|
||||
Running bool `json:"running" yaml:"-"`
|
||||
}
|
||||
)
|
||||
|
|
112
internal/docker/idlewatcher/config/config.go
Normal file
112
internal/docker/idlewatcher/config/config.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package idlewatcher
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
|
||||
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
|
||||
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
|
||||
StopMethod StopMethod `json:"stop_method,omitempty"`
|
||||
StopSignal Signal `json:"stop_signal,omitempty"`
|
||||
|
||||
DockerHost string `json:"docker_host,omitempty"`
|
||||
ContainerName string `json:"container_name,omitempty"`
|
||||
ContainerID string `json:"container_id,omitempty"`
|
||||
ContainerRunning bool `json:"container_running,omitempty"`
|
||||
}
|
||||
StopMethod string
|
||||
Signal string
|
||||
)
|
||||
|
||||
const (
|
||||
StopMethodPause StopMethod = "pause"
|
||||
StopMethodStop StopMethod = "stop"
|
||||
StopMethodKill StopMethod = "kill"
|
||||
)
|
||||
|
||||
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
|
||||
if cont == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cont.IdleTimeout == "" {
|
||||
return &Config{
|
||||
DockerHost: cont.DockerHost,
|
||||
ContainerName: cont.ContainerName,
|
||||
ContainerID: cont.ContainerID,
|
||||
ContainerRunning: cont.Running,
|
||||
}, nil
|
||||
}
|
||||
|
||||
b := E.NewBuilder("invalid idlewatcher config")
|
||||
defer b.To(&res)
|
||||
|
||||
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout)
|
||||
b.Add(err.Subjectf("%s", "idle_timeout"))
|
||||
|
||||
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout)
|
||||
b.Add(err.Subjectf("%s", "wake_timeout"))
|
||||
|
||||
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
|
||||
b.Add(err.Subjectf("%s", "stop_timeout"))
|
||||
|
||||
stopMethod, err := validateStopMethod(cont.StopMethod)
|
||||
b.Add(err)
|
||||
|
||||
signal, err := validateSignal(cont.StopSignal)
|
||||
b.Add(err)
|
||||
|
||||
if err := b.Build(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return &Config{
|
||||
IdleTimeout: idleTimeout,
|
||||
WakeTimeout: wakeTimeout,
|
||||
StopTimeout: int(stopTimeout.Seconds()),
|
||||
StopMethod: stopMethod,
|
||||
StopSignal: signal,
|
||||
|
||||
DockerHost: cont.DockerHost,
|
||||
ContainerName: cont.ContainerName,
|
||||
ContainerID: cont.ContainerID,
|
||||
ContainerRunning: cont.Running,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, E.Invalid("duration", value).With(err)
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, E.Invalid("duration", "negative value")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func validateSignal(s string) (Signal, E.NestedError) {
|
||||
switch s {
|
||||
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
|
||||
"INT", "TERM", "HUP", "QUIT":
|
||||
return Signal(s), nil
|
||||
}
|
||||
|
||||
return "", E.Invalid("signal", s)
|
||||
}
|
||||
|
||||
func validateStopMethod(s string) (StopMethod, E.NestedError) {
|
||||
sm := StopMethod(s)
|
||||
switch sm {
|
||||
case StopMethodPause, StopMethodStop, StopMethodKill:
|
||||
return sm, nil
|
||||
default:
|
||||
return "", E.Invalid("stop_method", sm)
|
||||
}
|
||||
}
|
|
@ -20,16 +20,15 @@ var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(lo
|
|||
|
||||
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
|
||||
|
||||
func (w *Watcher) makeRespBody(format string, args ...any) []byte {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
func (w *Watcher) makeLoadingPageBody() []byte {
|
||||
msg := fmt.Sprintf("%s is starting...", w.ContainerName)
|
||||
|
||||
data := new(templateData)
|
||||
data.CheckRedirectHeader = headerCheckRedirect
|
||||
data.Title = w.ContainerName
|
||||
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
|
||||
data.Message = strings.ReplaceAll(data.Message, " ", " ")
|
||||
data.Message = strings.ReplaceAll(msg, " ", " ")
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
|
||||
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect)))
|
||||
err := loadingPageTmpl.Execute(buf, data)
|
||||
if err != nil { // should never happen in production
|
||||
panic(err)
|
|
@ -1,197 +1,133 @@
|
|||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Waker struct {
|
||||
*Watcher
|
||||
type Waker interface {
|
||||
health.HealthMonitor
|
||||
http.Handler
|
||||
net.Stream
|
||||
}
|
||||
|
||||
type waker struct {
|
||||
_ U.NoCopy
|
||||
|
||||
client *http.Client
|
||||
rp *gphttp.ReverseProxy
|
||||
stream net.Stream
|
||||
hc health.HealthChecker
|
||||
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker {
|
||||
return &Waker{
|
||||
Watcher: w,
|
||||
client: &http.Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Transport: rp.Transport,
|
||||
},
|
||||
rp: rp,
|
||||
const (
|
||||
idleWakerCheckInterval = 100 * time.Millisecond
|
||||
idleWakerCheckTimeout = time.Second
|
||||
)
|
||||
|
||||
// TODO: support stream
|
||||
|
||||
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) {
|
||||
hcCfg := entry.HealthCheckConfig()
|
||||
hcCfg.Timeout = idleWakerCheckTimeout
|
||||
|
||||
waker := &waker{
|
||||
rp: rp,
|
||||
stream: stream,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
shouldNext := w.wake(rw, r)
|
||||
if !shouldNext {
|
||||
return
|
||||
watcher, err := registerWatcher(providerSubTask, entry, waker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w.rp.ServeHTTP(rw, r)
|
||||
|
||||
if rp != nil {
|
||||
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
|
||||
} else if stream != nil {
|
||||
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
|
||||
} else {
|
||||
panic("both nil")
|
||||
}
|
||||
return watcher, nil
|
||||
}
|
||||
|
||||
/* HealthMonitor interface */
|
||||
|
||||
func (w *Waker) Start() {}
|
||||
|
||||
func (w *Waker) Stop() {
|
||||
w.Unregister()
|
||||
// lifetime should follow route provider
|
||||
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) {
|
||||
return newWaker(providerSubTask, entry, rp, nil)
|
||||
}
|
||||
|
||||
func (w *Waker) UpdateConfig(config health.HealthCheckConfig) {
|
||||
panic("use idlewatcher.Register instead")
|
||||
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) {
|
||||
return newWaker(providerSubTask, entry, nil, stream)
|
||||
}
|
||||
|
||||
func (w *Waker) Name() string {
|
||||
// Start implements health.HealthMonitor.
|
||||
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError {
|
||||
w.task.OnComplete("stop route", func() {
|
||||
routeSubTask.Parent().Finish("watcher stopped")
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finish implements health.HealthMonitor.
|
||||
func (w *Watcher) Finish(reason string) {}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
func (w *Watcher) Name() string {
|
||||
return w.String()
|
||||
}
|
||||
|
||||
func (w *Waker) String() string {
|
||||
return string(w.Alias)
|
||||
// String implements health.HealthMonitor.
|
||||
func (w *Watcher) String() string {
|
||||
return w.ContainerName
|
||||
}
|
||||
|
||||
func (w *Waker) Status() health.Status {
|
||||
if w.ready.Load() {
|
||||
return health.StatusHealthy
|
||||
}
|
||||
if !w.ContainerRunning {
|
||||
return health.StatusNapping
|
||||
}
|
||||
return health.StatusStarting
|
||||
}
|
||||
|
||||
func (w *Waker) Uptime() time.Duration {
|
||||
// Uptime implements health.HealthMonitor.
|
||||
func (w *Watcher) Uptime() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (w *Waker) MarshalJSON() ([]byte, error) {
|
||||
var url types.URL
|
||||
if w.URL.String() != "http://:0" {
|
||||
url = w.URL
|
||||
// Status implements health.HealthMonitor.
|
||||
func (w *Watcher) Status() health.Status {
|
||||
if !w.ContainerRunning {
|
||||
return health.StatusNapping
|
||||
}
|
||||
|
||||
if w.ready.Load() {
|
||||
return health.StatusHealthy
|
||||
}
|
||||
|
||||
healthy, _, err := w.hc.CheckHealth()
|
||||
switch {
|
||||
case err != nil:
|
||||
return health.StatusError
|
||||
case healthy:
|
||||
w.ready.Store(true)
|
||||
return health.StatusHealthy
|
||||
default:
|
||||
return health.StatusStarting
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (w *Watcher) MarshalJSON() ([]byte, error) {
|
||||
var url net.URL
|
||||
if w.hc.URL().Port() != "0" {
|
||||
url = w.hc.URL()
|
||||
}
|
||||
return (&health.JSONRepresentation{
|
||||
Name: w.Name(),
|
||||
Status: w.Status(),
|
||||
Config: &health.HealthCheckConfig{
|
||||
Interval: w.IdleTimeout,
|
||||
Timeout: w.WakeTimeout,
|
||||
},
|
||||
URL: url,
|
||||
Config: w.hc.Config(),
|
||||
URL: url,
|
||||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
/* End of HealthMonitor interface */
|
||||
|
||||
func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
|
||||
w.resetIdleTimer()
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
// pass through if container is ready
|
||||
if w.ready.Load() {
|
||||
return true
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
|
||||
defer cancel()
|
||||
|
||||
accept := gphttp.GetAccept(r.Header)
|
||||
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
|
||||
|
||||
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
|
||||
if !isCheckRedirect && acceptHTML {
|
||||
// Send a loading response to the client
|
||||
body := w.makeRespBody("%s waking up...", w.ContainerName)
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
rw.Header().Add("Cache-Control", "no-cache")
|
||||
rw.Header().Add("Cache-Control", "no-store")
|
||||
rw.Header().Add("Cache-Control", "must-revalidate")
|
||||
if _, err := rw.Write(body); err != nil {
|
||||
w.l.Errorf("error writing http response: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
w.l.Debug("wake signal received")
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.l.Error(E.FailWith("wake", err))
|
||||
http.Error(rw, "Error waking container", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// maybe another request came in while we were waiting for the wake
|
||||
if w.ready.Load() {
|
||||
if isCheckRedirect {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
wakeReq, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodHead,
|
||||
w.URL.String(),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
w.l.Errorf("new request err to %s: %s", r.URL, err)
|
||||
http.Error(rw, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
wakeResp, err := w.client.Do(wakeReq)
|
||||
if err == nil && wakeResp.StatusCode != http.StatusServiceUnavailable {
|
||||
w.ready.Store(true)
|
||||
w.l.Debug("awaken")
|
||||
if isCheckRedirect {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
logrus.Infof("container %s is ready, passing through to %s", w.Alias, w.rp.TargetURL)
|
||||
return true
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// static HealthMonitor interface check
|
||||
func (w *Waker) _() health.HealthMonitor {
|
||||
return w
|
||||
}
|
||||
|
|
105
internal/docker/idlewatcher/waker_http.go
Normal file
105
internal/docker/idlewatcher/waker_http.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
shouldNext := w.wakeFromHTTP(rw, r)
|
||||
if !shouldNext {
|
||||
return
|
||||
}
|
||||
w.rp.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
|
||||
w.resetIdleTimer()
|
||||
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
}
|
||||
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return true
|
||||
}
|
||||
|
||||
accept := gphttp.GetAccept(r.Header)
|
||||
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
|
||||
|
||||
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
|
||||
if !isCheckRedirect && acceptHTML {
|
||||
// Send a loading response to the client
|
||||
body := w.makeLoadingPageBody()
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
rw.Header().Add("Cache-Control", "no-cache")
|
||||
rw.Header().Add("Cache-Control", "no-store")
|
||||
rw.Header().Add("Cache-Control", "must-revalidate")
|
||||
rw.Header().Add("Connection", "close")
|
||||
if _, err := rw.Write(body); err != nil {
|
||||
w.l.Errorf("error writing http response: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
checkCancelled := func() bool {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context()))
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
w.l.Debugf("wake cancelled: %s", context.Cause(ctx))
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if checkCancelled() {
|
||||
return false
|
||||
}
|
||||
|
||||
w.l.Debug("wake signal received")
|
||||
err := w.wakeIfStopped()
|
||||
if err != nil {
|
||||
w.l.Error(E.FailWith("wake", err))
|
||||
http.Error(rw, "Error waking container", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if checkCancelled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
if isCheckRedirect {
|
||||
logrus.Debugf("container %s is ready, redirecting...", w.String())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
|
||||
return true
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(idleWakerCheckInterval)
|
||||
}
|
||||
}
|
87
internal/docker/idlewatcher/waker_stream.go
Normal file
87
internal/docker/idlewatcher/waker_stream.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package idlewatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Setup() error {
|
||||
return w.stream.Setup()
|
||||
}
|
||||
|
||||
// Accept implements types.Stream.
|
||||
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
|
||||
conn, err = w.stream.Accept()
|
||||
// timeout means no connection is accepted
|
||||
var nErr *net.OpError
|
||||
ok := errors.As(err, &nErr)
|
||||
if ok && nErr.Timeout() {
|
||||
return
|
||||
}
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w.stream.Accept()
|
||||
}
|
||||
|
||||
// CloseListeners implements types.Stream.
|
||||
func (w *Watcher) CloseListeners() {
|
||||
w.stream.CloseListeners()
|
||||
}
|
||||
|
||||
// Handle implements types.Stream.
|
||||
func (w *Watcher) Handle(conn types.StreamConn) error {
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.stream.Handle(conn)
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromStream() error {
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
w.l.Debug("wake signal received")
|
||||
wakeErr := w.wakeIfStopped()
|
||||
if wakeErr != nil {
|
||||
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr)
|
||||
w.l.Error(wakeErr)
|
||||
return wakeErr
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
cause := w.task.FinishCause()
|
||||
w.l.Debugf("wake cancelled: %s", cause)
|
||||
return cause
|
||||
case <-ctx.Done():
|
||||
cause := context.Cause(ctx)
|
||||
w.l.Debugf("wake cancelled: %s", cause)
|
||||
return cause
|
||||
default:
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
// retry until the container is ready or timeout
|
||||
time.Sleep(idleWakerCheckInterval)
|
||||
}
|
||||
}
|
|
@ -2,191 +2,193 @@ package idlewatcher
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
type (
|
||||
Watcher struct {
|
||||
*P.ReverseProxyEntry
|
||||
_ U.NoCopy
|
||||
|
||||
client D.Client
|
||||
*idlewatcher.Config
|
||||
*waker
|
||||
|
||||
ready atomic.Bool // whether the site is ready to accept connection
|
||||
client D.Client
|
||||
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
|
||||
|
||||
ticker *time.Ticker
|
||||
|
||||
task common.Task
|
||||
cancel context.CancelFunc
|
||||
|
||||
refCount *U.RefCount
|
||||
|
||||
l logrus.FieldLogger
|
||||
ticker *time.Ticker
|
||||
task task.Task
|
||||
l *logrus.Entry
|
||||
}
|
||||
|
||||
WakeDone <-chan error
|
||||
WakeFunc func() WakeDone
|
||||
StopCallback func() E.NestedError
|
||||
StopCallback func() error
|
||||
)
|
||||
|
||||
var (
|
||||
watcherMap = F.NewMapOf[string, *Watcher]()
|
||||
watcherMapMu sync.Mutex
|
||||
|
||||
portHistoryMap = F.NewMapOf[PT.Alias, string]()
|
||||
|
||||
logger = logrus.WithField("module", "idle_watcher")
|
||||
)
|
||||
|
||||
func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
|
||||
failure := E.Failure("idle_watcher register")
|
||||
const dockerReqTimeout = 3 * time.Second
|
||||
|
||||
if entry.IdleTimeout == 0 {
|
||||
return nil, failure.With(E.Invalid("idle_timeout", 0))
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) {
|
||||
failure := E.Failure("idle_watcher register")
|
||||
cfg := entry.IdlewatcherConfig()
|
||||
|
||||
if cfg.IdleTimeout == 0 {
|
||||
panic("should not reach here")
|
||||
}
|
||||
|
||||
watcherMapMu.Lock()
|
||||
defer watcherMapMu.Unlock()
|
||||
|
||||
key := entry.ContainerID
|
||||
|
||||
if entry.URL.Port() != "0" {
|
||||
portHistoryMap.Store(entry.Alias, entry.URL.Port())
|
||||
}
|
||||
key := cfg.ContainerID
|
||||
|
||||
if w, ok := watcherMap.Load(key); ok {
|
||||
w.refCount.Add()
|
||||
w.ReverseProxyEntry = entry
|
||||
w.Config = cfg
|
||||
w.waker = waker
|
||||
w.resetIdleTimer()
|
||||
return w, nil
|
||||
}
|
||||
|
||||
client, err := D.ConnectClient(entry.DockerHost)
|
||||
client, err := D.ConnectClient(cfg.DockerHost)
|
||||
if err.HasError() {
|
||||
return nil, failure.With(err)
|
||||
}
|
||||
|
||||
w := &Watcher{
|
||||
ReverseProxyEntry: entry,
|
||||
client: client,
|
||||
refCount: U.NewRefCounter(),
|
||||
ticker: time.NewTicker(entry.IdleTimeout),
|
||||
l: logger.WithField("container", entry.ContainerName),
|
||||
Config: cfg,
|
||||
waker: waker,
|
||||
client: client,
|
||||
task: providerSubtask,
|
||||
ticker: time.NewTicker(cfg.IdleTimeout),
|
||||
l: logger.WithField("container", cfg.ContainerName),
|
||||
}
|
||||
w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias)
|
||||
w.stopByMethod = w.getStopCallback()
|
||||
|
||||
watcherMap.Store(key, w)
|
||||
|
||||
go w.watchUntilCancel()
|
||||
go func() {
|
||||
cause := w.watchUntilDestroy()
|
||||
|
||||
watcherMapMu.Lock()
|
||||
watcherMap.Delete(w.ContainerID)
|
||||
watcherMapMu.Unlock()
|
||||
|
||||
w.ticker.Stop()
|
||||
w.client.Close()
|
||||
w.task.Finish(cause.Error())
|
||||
}()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) Unregister() {
|
||||
w.refCount.Sub()
|
||||
}
|
||||
|
||||
func (w *Watcher) containerStop() error {
|
||||
return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{
|
||||
func (w *Watcher) containerStop(ctx context.Context) error {
|
||||
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
|
||||
Signal: string(w.StopSignal),
|
||||
Timeout: &w.StopTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) containerPause() error {
|
||||
return w.client.ContainerPause(w.task.Context(), w.ContainerID)
|
||||
func (w *Watcher) containerPause(ctx context.Context) error {
|
||||
return w.client.ContainerPause(ctx, w.ContainerID)
|
||||
}
|
||||
|
||||
func (w *Watcher) containerKill() error {
|
||||
return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal))
|
||||
func (w *Watcher) containerKill(ctx context.Context) error {
|
||||
return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal))
|
||||
}
|
||||
|
||||
func (w *Watcher) containerUnpause() error {
|
||||
return w.client.ContainerUnpause(w.task.Context(), w.ContainerID)
|
||||
func (w *Watcher) containerUnpause(ctx context.Context) error {
|
||||
return w.client.ContainerUnpause(ctx, w.ContainerID)
|
||||
}
|
||||
|
||||
func (w *Watcher) containerStart() error {
|
||||
return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{})
|
||||
func (w *Watcher) containerStart(ctx context.Context) error {
|
||||
return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{})
|
||||
}
|
||||
|
||||
func (w *Watcher) containerStatus() (string, E.NestedError) {
|
||||
func (w *Watcher) containerStatus() (string, error) {
|
||||
if !w.client.Connected() {
|
||||
return "", E.Failure("docker client closed")
|
||||
return "", errors.New("docker client not connected")
|
||||
}
|
||||
json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID)
|
||||
ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
|
||||
defer cancel()
|
||||
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
|
||||
if err != nil {
|
||||
return "", E.FailWith("inspect container", err)
|
||||
return "", fmt.Errorf("failed to inspect container: %w", err)
|
||||
}
|
||||
return json.State.Status, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeIfStopped() E.NestedError {
|
||||
if w.ready.Load() || w.ContainerRunning {
|
||||
func (w *Watcher) wakeIfStopped() error {
|
||||
if w.ContainerRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
status, err := w.containerStatus()
|
||||
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
|
||||
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
|
||||
defer cancel()
|
||||
|
||||
// !Hard coded here since theres no constants from Docker API
|
||||
switch status {
|
||||
case "exited", "dead":
|
||||
return E.From(w.containerStart())
|
||||
return w.containerStart(ctx)
|
||||
case "paused":
|
||||
return E.From(w.containerUnpause())
|
||||
return w.containerUnpause(ctx)
|
||||
case "running":
|
||||
return nil
|
||||
default:
|
||||
return E.Unexpected("container state", status)
|
||||
panic("should not reach here")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) getStopCallback() StopCallback {
|
||||
var cb func() error
|
||||
var cb func(context.Context) error
|
||||
switch w.StopMethod {
|
||||
case PT.StopMethodPause:
|
||||
case idlewatcher.StopMethodPause:
|
||||
cb = w.containerPause
|
||||
case PT.StopMethodStop:
|
||||
case idlewatcher.StopMethodStop:
|
||||
cb = w.containerStop
|
||||
case PT.StopMethodKill:
|
||||
case idlewatcher.StopMethodKill:
|
||||
cb = w.containerKill
|
||||
default:
|
||||
panic("should not reach here")
|
||||
}
|
||||
return func() E.NestedError {
|
||||
status, err := w.containerStatus()
|
||||
if err.HasError() {
|
||||
return err
|
||||
}
|
||||
if status != "running" {
|
||||
return nil
|
||||
}
|
||||
return E.From(cb())
|
||||
return func() error {
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
|
||||
defer cancel()
|
||||
return cb(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) resetIdleTimer() {
|
||||
w.l.Trace("reset idle timer")
|
||||
w.ticker.Reset(w.IdleTimeout)
|
||||
}
|
||||
|
||||
func (w *Watcher) watchUntilCancel() {
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{
|
||||
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) {
|
||||
eventTask = w.task.Subtask("watcher for %s", w.ContainerID)
|
||||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
|
||||
Filters: W.NewDockerFilter(
|
||||
W.DockerFilterContainer,
|
||||
W.DockerrFilterContainer(w.ContainerID),
|
||||
|
@ -194,34 +196,47 @@ func (w *Watcher) watchUntilCancel() {
|
|||
W.DockerFilterStop,
|
||||
W.DockerFilterDie,
|
||||
W.DockerFilterKill,
|
||||
W.DockerFilterDestroy,
|
||||
W.DockerFilterPause,
|
||||
W.DockerFilterUnpause,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
w.cancel()
|
||||
w.ticker.Stop()
|
||||
w.client.Close()
|
||||
watcherMap.Delete(w.ContainerID)
|
||||
w.task.Finished()
|
||||
}()
|
||||
// watchUntilDestroy waits for the container to be created, started, or unpaused,
|
||||
// and then reset the idle timer.
|
||||
//
|
||||
// When the container is stopped, paused,
|
||||
// or killed, the idle timer is stopped and the ContainerRunning flag is set to false.
|
||||
//
|
||||
// When the idle timer fires, the container is stopped according to the
|
||||
// stop method.
|
||||
//
|
||||
// it exits only if the context is canceled, the container is destroyed,
|
||||
// errors occured on docker client, or route provider died (mainly caused by config reload).
|
||||
func (w *Watcher) watchUntilDestroy() error {
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
w.l.Debug("stopped by context done")
|
||||
return
|
||||
case <-w.refCount.Zero():
|
||||
w.l.Debug("stopped by zero ref count")
|
||||
return
|
||||
cause := context.Cause(w.task.Context())
|
||||
w.l.Debugf("watcher stopped by context done: %s", cause)
|
||||
return cause
|
||||
case err := <-dockerEventErrCh:
|
||||
if err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("docker watcher", err))
|
||||
return
|
||||
return err.Error()
|
||||
}
|
||||
case e := <-dockerEventCh:
|
||||
switch {
|
||||
case e.Action == events.ActionContainerDestroy:
|
||||
w.ContainerRunning = false
|
||||
w.ready.Store(false)
|
||||
w.l.Info("watcher stopped by container destruction")
|
||||
return errors.New("container destroyed")
|
||||
// create / start / unpause
|
||||
case e.Action.IsContainerWake():
|
||||
w.ContainerRunning = true
|
||||
|
@ -229,18 +244,31 @@ func (w *Watcher) watchUntilCancel() {
|
|||
w.l.Info("container awaken")
|
||||
case e.Action.IsContainerSleep(): // stop / pause / kil
|
||||
w.ContainerRunning = false
|
||||
w.ticker.Stop()
|
||||
w.ready.Store(false)
|
||||
w.ticker.Stop()
|
||||
default:
|
||||
w.l.Errorf("unexpected docker event: %s", e)
|
||||
}
|
||||
// container name changed should also change the container id
|
||||
if w.ContainerName != e.ActorName {
|
||||
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName)
|
||||
w.ContainerName = e.ActorName
|
||||
}
|
||||
if w.ContainerID != e.ActorID {
|
||||
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID)
|
||||
w.ContainerID = e.ActorID
|
||||
// recreate event stream
|
||||
eventTask.Finish("recreate event stream")
|
||||
eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher)
|
||||
}
|
||||
case <-w.ticker.C:
|
||||
w.l.Debug("idle timeout")
|
||||
w.ticker.Stop()
|
||||
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
|
||||
} else {
|
||||
w.l.Info("stopped by idle timeout")
|
||||
if w.ContainerRunning {
|
||||
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err)
|
||||
} else {
|
||||
w.l.Info("container stopped by idle timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package docker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -19,7 +20,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
|
|||
}
|
||||
|
||||
func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
|
||||
defer cancel()
|
||||
|
||||
json, err := c.ContainerInspect(ctx, containerID)
|
||||
|
|
|
@ -17,35 +17,36 @@ type builder struct {
|
|||
}
|
||||
|
||||
func NewBuilder(format string, args ...any) Builder {
|
||||
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
|
||||
if len(args) > 0 {
|
||||
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
|
||||
}
|
||||
return Builder{&builder{message: format}}
|
||||
}
|
||||
|
||||
// adding nil / nil is no-op,
|
||||
// you may safely pass expressions returning error to it.
|
||||
func (b Builder) Add(err NestedError) Builder {
|
||||
func (b Builder) Add(err NestedError) {
|
||||
if err != nil {
|
||||
b.Lock()
|
||||
b.errors = append(b.errors, err)
|
||||
b.Unlock()
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b Builder) AddE(err error) Builder {
|
||||
return b.Add(From(err))
|
||||
func (b Builder) AddE(err error) {
|
||||
b.Add(From(err))
|
||||
}
|
||||
|
||||
func (b Builder) Addf(format string, args ...any) Builder {
|
||||
return b.Add(errorf(format, args...))
|
||||
func (b Builder) Addf(format string, args ...any) {
|
||||
b.Add(errorf(format, args...))
|
||||
}
|
||||
|
||||
func (b Builder) AddRangeE(errs ...error) Builder {
|
||||
func (b Builder) AddRangeE(errs ...error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
for _, err := range errs {
|
||||
b.AddE(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Build builds a NestedError based on the errors collected in the Builder.
|
||||
|
|
|
@ -2,6 +2,7 @@ package error
|
|||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
|
@ -16,6 +17,7 @@ var (
|
|||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrTypeError = stderrors.New("type error")
|
||||
ErrTypeMismatch = stderrors.New("type mismatch")
|
||||
ErrPanicRecv = stderrors.New("panic")
|
||||
)
|
||||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
@ -75,3 +77,7 @@ func TypeError2(subject any, from, to reflect.Value) NestedError {
|
|||
func TypeMismatch[Expect any](value any) NestedError {
|
||||
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
|
||||
}
|
||||
|
||||
func PanicRecv(format string, args ...any) NestedError {
|
||||
return errorf("%w%s", ErrPanicRecv, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
|
|
@ -4,18 +4,20 @@ import (
|
|||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
type ipHash struct {
|
||||
*LoadBalancer
|
||||
realIP *middleware.Middleware
|
||||
pool servers
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl {
|
||||
impl := &ipHash{LoadBalancer: lb}
|
||||
impl := new(ipHash)
|
||||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
|
@ -26,10 +28,37 @@ func (lb *LoadBalancer) newIPHash() impl {
|
|||
}
|
||||
return impl
|
||||
}
|
||||
func (ipHash) OnAddServer(srv *Server) {}
|
||||
func (ipHash) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *ipHash) OnAddServer(srv *Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
impl.pool[i] = srv
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
impl.pool = append(impl.pool, srv)
|
||||
}
|
||||
|
||||
func (impl *ipHash) OnRemoveServer(srv *Server) {
|
||||
impl.mu.Lock()
|
||||
defer impl.mu.Unlock()
|
||||
|
||||
for i, s := range impl.pool {
|
||||
if s == srv {
|
||||
impl.pool[i] = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
if impl.realIP != nil {
|
||||
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
|
||||
} else {
|
||||
|
@ -37,7 +66,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
}
|
||||
|
||||
func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
|
@ -45,10 +74,12 @@ func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
if impl.pool[idx].Status().Bad() {
|
||||
|
||||
srv := impl.pool[idx]
|
||||
if srv == nil || srv.Status().Bad() {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
}
|
||||
impl.pool[idx].ServeHTTP(rw, r)
|
||||
srv.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func hashIP(ip string) uint32 {
|
||||
|
|
|
@ -5,8 +5,9 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/log"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
|
@ -28,7 +29,9 @@ type (
|
|||
impl
|
||||
*Config
|
||||
|
||||
pool servers
|
||||
task task.Task
|
||||
|
||||
pool Pool
|
||||
poolMu sync.Mutex
|
||||
|
||||
sumWeight weightType
|
||||
|
@ -41,11 +44,35 @@ type (
|
|||
const maxWeight weightType = 100
|
||||
|
||||
func New(cfg *Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)}
|
||||
lb := &LoadBalancer{
|
||||
Config: new(Config),
|
||||
pool: newPool(),
|
||||
task: task.DummyTask(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = routeSubtask
|
||||
lb.task.OnComplete("loadbalancer cleanup", func() {
|
||||
if lb.impl != nil {
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
lb.impl.OnRemoveServer(v)
|
||||
})
|
||||
}
|
||||
lb.pool.Clear()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (lb *LoadBalancer) Finish(reason string) {
|
||||
lb.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) updateImpl() {
|
||||
switch lb.Mode {
|
||||
case Unset, RoundRobin:
|
||||
|
@ -57,9 +84,9 @@ func (lb *LoadBalancer) updateImpl() {
|
|||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
}
|
||||
for _, srv := range lb.pool {
|
||||
lb.pool.RangeAll(func(_ string, srv *Server) {
|
||||
lb.impl.OnAddServer(srv)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
|
||||
|
@ -91,55 +118,60 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
|
|||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.pool = append(lb.pool, srv)
|
||||
if lb.pool.Has(srv.Name) {
|
||||
old, _ := lb.pool.Load(srv.Name)
|
||||
lb.sumWeight -= old.Weight
|
||||
lb.impl.OnRemoveServer(old)
|
||||
}
|
||||
lb.pool.Store(srv.Name, srv)
|
||||
lb.sumWeight += srv.Weight
|
||||
|
||||
lb.Rebalance()
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
|
||||
logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.sumWeight -= srv.Weight
|
||||
lb.Rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
for i, s := range lb.pool {
|
||||
if s == srv {
|
||||
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if lb.IsEmpty() {
|
||||
lb.pool = nil
|
||||
if !lb.pool.Has(srv.Name) {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
|
||||
lb.pool.Delete(srv.Name)
|
||||
|
||||
lb.sumWeight -= srv.Weight
|
||||
lb.rebalance()
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
if lb.pool.Size() == 0 {
|
||||
lb.task.Finish("no server left")
|
||||
logger.Infof("[remove] loadbalancer %s stopped", lb.Link)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) IsEmpty() bool {
|
||||
return len(lb.pool) == 0
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Rebalance() {
|
||||
func (lb *LoadBalancer) rebalance() {
|
||||
if lb.sumWeight == maxWeight {
|
||||
return
|
||||
}
|
||||
if lb.pool.Size() == 0 {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / weightType(len(lb.pool))
|
||||
remainder := maxWeight % weightType(len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
weightEach := maxWeight / weightType(lb.pool.Size())
|
||||
remainder := maxWeight % weightType(lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightEach
|
||||
lb.sumWeight += weightEach
|
||||
if remainder > 0 {
|
||||
s.Weight++
|
||||
remainder--
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -147,18 +179,18 @@ func (lb *LoadBalancer) Rebalance() {
|
|||
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||
lb.sumWeight = 0
|
||||
|
||||
for _, s := range lb.pool {
|
||||
lb.pool.RangeAll(func(_ string, s *Server) {
|
||||
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||
lb.sumWeight += s.Weight
|
||||
}
|
||||
})
|
||||
|
||||
delta := maxWeight - lb.sumWeight
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
for _, s := range lb.pool {
|
||||
lb.pool.Range(func(_ string, s *Server) bool {
|
||||
if delta == 0 {
|
||||
break
|
||||
return false
|
||||
}
|
||||
if delta > 0 {
|
||||
s.Weight++
|
||||
|
@ -169,7 +201,8 @@ func (lb *LoadBalancer) Rebalance() {
|
|||
lb.sumWeight--
|
||||
delta++
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -181,23 +214,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Start() {
|
||||
if lb.sumWeight != 0 {
|
||||
log.Warnf("weighted mode not supported yet")
|
||||
}
|
||||
|
||||
lb.startTime = time.Now()
|
||||
logger.Debugf("loadbalancer %s started", lb.Link)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Stop() {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
lb.pool = nil
|
||||
|
||||
logger.Debugf("loadbalancer %s stopped", lb.Link)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Uptime() time.Duration {
|
||||
return time.Since(lb.startTime)
|
||||
}
|
||||
|
@ -205,9 +221,10 @@ func (lb *LoadBalancer) Uptime() time.Duration {
|
|||
// MarshalJSON implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
|
||||
extra := make(map[string]any)
|
||||
for _, v := range lb.pool {
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
extra[v.Name] = v.healthMon
|
||||
}
|
||||
})
|
||||
|
||||
return (&health.JSONRepresentation{
|
||||
Name: lb.Name(),
|
||||
Status: lb.Status(),
|
||||
|
@ -227,7 +244,7 @@ func (lb *LoadBalancer) Name() string {
|
|||
|
||||
// Status implements health.HealthMonitor.
|
||||
func (lb *LoadBalancer) Status() health.Status {
|
||||
if len(lb.pool) == 0 {
|
||||
if lb.pool.Size() == 0 {
|
||||
return health.StatusUnknown
|
||||
}
|
||||
if len(lb.availServers()) == 0 {
|
||||
|
@ -241,21 +258,13 @@ func (lb *LoadBalancer) String() string {
|
|||
return lb.Name()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() servers {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
avail := make(servers, 0, len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
if s.Status().Bad() {
|
||||
continue
|
||||
func (lb *LoadBalancer) availServers() []*Server {
|
||||
avail := make([]*Server, 0, lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, srv *Server) {
|
||||
if srv.Status().Bad() {
|
||||
return
|
||||
}
|
||||
avail = append(avail, s)
|
||||
}
|
||||
avail = append(avail, srv)
|
||||
})
|
||||
return avail
|
||||
}
|
||||
|
||||
// static HealthMonitor interface check
|
||||
func (lb *LoadBalancer) _() health.HealthMonitor {
|
||||
return lb
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ func TestRebalance(t *testing.T) {
|
|||
for range 10 {
|
||||
lb.AddServer(&Server{})
|
||||
}
|
||||
lb.Rebalance()
|
||||
lb.rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
|
@ -23,7 +23,7 @@ func TestRebalance(t *testing.T) {
|
|||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
|
@ -36,7 +36,7 @@ func TestRebalance(t *testing.T) {
|
|||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
lb.rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
|
@ -20,9 +21,12 @@ type (
|
|||
handler http.Handler
|
||||
healthMon health.HealthMonitor
|
||||
}
|
||||
servers []*Server
|
||||
servers = []*Server
|
||||
Pool = F.Map[string, *Server]
|
||||
)
|
||||
|
||||
var newPool = F.NewMap[Pool]
|
||||
|
||||
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
|
||||
srv := &Server{
|
||||
Name: name,
|
||||
|
|
|
@ -48,11 +48,11 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
|||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
continue
|
||||
}
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
chain = append(chain, m)
|
||||
}
|
||||
if chainErr.HasError() {
|
||||
|
|
19
internal/net/types/stream.go
Normal file
19
internal/net/types/stream.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Stream interface {
|
||||
fmt.Stringer
|
||||
Setup() error
|
||||
Accept() (conn StreamConn, err error)
|
||||
Handle(conn StreamConn) error
|
||||
CloseListeners()
|
||||
}
|
||||
|
||||
type StreamConn interface {
|
||||
RemoteAddr() net.Addr
|
||||
Close() error
|
||||
}
|
|
@ -1,177 +0,0 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type (
|
||||
ReverseProxyEntry struct { // real model after validation
|
||||
Raw *types.RawEntry `json:"raw"`
|
||||
|
||||
Alias T.Alias `json:"alias,omitempty"`
|
||||
Scheme T.Scheme `json:"scheme,omitempty"`
|
||||
URL net.URL `json:"url,omitempty"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
|
||||
PathPatterns T.PathPatterns `json:"path_patterns,omitempty"`
|
||||
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
|
||||
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
|
||||
Middlewares D.NestedLabelMap `json:"middlewares,omitempty"`
|
||||
|
||||
/* Docker only */
|
||||
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
|
||||
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
|
||||
StopMethod T.StopMethod `json:"stop_method,omitempty"`
|
||||
StopTimeout int `json:"stop_timeout,omitempty"`
|
||||
StopSignal T.Signal `json:"stop_signal,omitempty"`
|
||||
DockerHost string `json:"docker_host,omitempty"`
|
||||
ContainerName string `json:"container_name,omitempty"`
|
||||
ContainerID string `json:"container_id,omitempty"`
|
||||
ContainerRunning bool `json:"container_running,omitempty"`
|
||||
}
|
||||
StreamEntry struct {
|
||||
Raw *types.RawEntry `json:"raw"`
|
||||
|
||||
Alias T.Alias `json:"alias,omitempty"`
|
||||
Scheme T.StreamScheme `json:"scheme,omitempty"`
|
||||
Host T.Host `json:"host,omitempty"`
|
||||
Port T.StreamPort `json:"port,omitempty"`
|
||||
Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
|
||||
return rp.IdleTimeout > 0 && rp.IsDocker()
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) UseLoadBalance() bool {
|
||||
return rp.LoadBalance != nil && rp.LoadBalance.Link != ""
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) IsDocker() bool {
|
||||
return rp.DockerHost != ""
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) IsZeroPort() bool {
|
||||
return rp.URL.Port() == "0"
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) ShouldNotServe() bool {
|
||||
return rp.IsZeroPort() && !rp.UseIdleWatcher()
|
||||
}
|
||||
|
||||
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
|
||||
m.FillMissingFields()
|
||||
|
||||
scheme, err := T.NewScheme(m.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var entry any
|
||||
e := E.NewBuilder("error validating entry")
|
||||
if scheme.IsStream() {
|
||||
entry = validateStreamEntry(m, e)
|
||||
} else {
|
||||
entry = validateRPEntry(m, scheme, e)
|
||||
}
|
||||
if err := e.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
|
||||
var stopTimeOut time.Duration
|
||||
cont := m.Container
|
||||
if cont == nil {
|
||||
cont = D.DummyContainer
|
||||
}
|
||||
|
||||
host, err := T.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
|
||||
port, err := T.ValidatePort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
||||
b.Add(err)
|
||||
|
||||
idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout)
|
||||
b.Add(err)
|
||||
|
||||
wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout)
|
||||
b.Add(err)
|
||||
|
||||
stopMethod, err := T.ValidateStopMethod(cont.StopMethod)
|
||||
b.Add(err)
|
||||
|
||||
if stopMethod == T.StopMethodStop {
|
||||
stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout)
|
||||
b.Add(err)
|
||||
}
|
||||
|
||||
stopSignal, err := T.ValidateSignal(cont.StopSignal)
|
||||
b.Add(err)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ReverseProxyEntry{
|
||||
Raw: m,
|
||||
Alias: T.NewAlias(m.Alias),
|
||||
Scheme: s,
|
||||
URL: net.NewURL(url),
|
||||
NoTLSVerify: m.NoTLSVerify,
|
||||
PathPatterns: pathPatterns,
|
||||
HealthCheck: &m.HealthCheck,
|
||||
LoadBalance: &m.LoadBalance,
|
||||
Middlewares: m.Middlewares,
|
||||
IdleTimeout: idleTimeout,
|
||||
WakeTimeout: wakeTimeout,
|
||||
StopMethod: stopMethod,
|
||||
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
|
||||
StopSignal: stopSignal,
|
||||
DockerHost: cont.DockerHost,
|
||||
ContainerName: cont.ContainerName,
|
||||
ContainerID: cont.ContainerID,
|
||||
ContainerRunning: cont.Running,
|
||||
}
|
||||
}
|
||||
|
||||
func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
|
||||
host, err := T.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
|
||||
port, err := T.ValidateStreamPort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
scheme, err := T.ValidateStreamScheme(m.Scheme)
|
||||
b.Add(err)
|
||||
|
||||
if b.HasError() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &StreamEntry{
|
||||
Raw: m,
|
||||
Alias: T.NewAlias(m.Alias),
|
||||
Scheme: *scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Healthcheck: &m.HealthCheck,
|
||||
}
|
||||
}
|
68
internal/proxy/entry/entry.go
Normal file
68
internal/proxy/entry/entry.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package entry
|
||||
|
||||
import (
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Entry interface {
|
||||
TargetName() string
|
||||
TargetURL() net.URL
|
||||
RawEntry() *RawEntry
|
||||
LoadBalanceConfig() *loadbalancer.Config
|
||||
HealthCheckConfig() *health.HealthCheckConfig
|
||||
IdlewatcherConfig() *idlewatcher.Config
|
||||
}
|
||||
|
||||
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) {
|
||||
m.FillMissingFields()
|
||||
|
||||
scheme, err := T.NewScheme(m.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var entry Entry
|
||||
e := E.NewBuilder("error validating entry")
|
||||
if scheme.IsStream() {
|
||||
entry = validateStreamEntry(m, e)
|
||||
} else {
|
||||
entry = validateRPEntry(m, scheme, e)
|
||||
}
|
||||
if err := e.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func IsDocker(entry Entry) bool {
|
||||
iw := entry.IdlewatcherConfig()
|
||||
return iw != nil && iw.ContainerID != ""
|
||||
}
|
||||
|
||||
func IsZeroPort(entry Entry) bool {
|
||||
return entry.TargetURL().Port() == "0"
|
||||
}
|
||||
|
||||
func ShouldNotServe(entry Entry) bool {
|
||||
return IsZeroPort(entry) && !UseIdleWatcher(entry)
|
||||
}
|
||||
|
||||
func UseLoadBalance(entry Entry) bool {
|
||||
lb := entry.LoadBalanceConfig()
|
||||
return lb != nil && lb.Link != ""
|
||||
}
|
||||
|
||||
func UseIdleWatcher(entry Entry) bool {
|
||||
iw := entry.IdlewatcherConfig()
|
||||
return iw != nil && iw.IdleTimeout > 0
|
||||
}
|
||||
|
||||
func UseHealthCheck(entry Entry) bool {
|
||||
hc := entry.HealthCheckConfig()
|
||||
return hc != nil && !hc.Disabled
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package types
|
||||
package entry
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
@ -21,16 +21,16 @@ type (
|
|||
|
||||
// raw entry object before validation
|
||||
// loaded from docker labels or yaml file
|
||||
Alias string `json:"-" yaml:"-"`
|
||||
Scheme string `json:"scheme,omitempty" yaml:"scheme"`
|
||||
Host string `json:"host,omitempty" yaml:"host"`
|
||||
Port string `json:"port,omitempty" yaml:"port"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
|
||||
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
|
||||
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
|
||||
LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
|
||||
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
|
||||
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
|
||||
Alias string `json:"-" yaml:"-"`
|
||||
Scheme string `json:"scheme,omitempty" yaml:"scheme"`
|
||||
Host string `json:"host,omitempty" yaml:"host"`
|
||||
Port string `json:"port,omitempty" yaml:"port"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
|
||||
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
|
||||
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
|
||||
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
|
||||
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
|
||||
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
|
||||
|
||||
/* Docker only */
|
||||
Container *docker.Container `json:"container,omitempty" yaml:"-"`
|
||||
|
@ -122,29 +122,41 @@ func (e *RawEntry) FillMissingFields() {
|
|||
}
|
||||
}
|
||||
|
||||
if e.HealthCheck.Interval == 0 {
|
||||
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
|
||||
if e.HealthCheck == nil {
|
||||
e.HealthCheck = new(health.HealthCheckConfig)
|
||||
}
|
||||
if e.HealthCheck.Timeout == 0 {
|
||||
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
|
||||
|
||||
if e.HealthCheck.Disabled {
|
||||
e.HealthCheck = nil
|
||||
} else {
|
||||
if e.HealthCheck.Interval == 0 {
|
||||
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
|
||||
}
|
||||
if e.HealthCheck.Timeout == 0 {
|
||||
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
|
||||
}
|
||||
}
|
||||
if cont.IdleTimeout == "" {
|
||||
cont.IdleTimeout = common.IdleTimeoutDefault
|
||||
}
|
||||
if cont.WakeTimeout == "" {
|
||||
cont.WakeTimeout = common.WakeTimeoutDefault
|
||||
}
|
||||
if cont.StopTimeout == "" {
|
||||
cont.StopTimeout = common.StopTimeoutDefault
|
||||
}
|
||||
if cont.StopMethod == "" {
|
||||
cont.StopMethod = common.StopMethodDefault
|
||||
|
||||
if cont.IdleTimeout != "" {
|
||||
if cont.WakeTimeout == "" {
|
||||
cont.WakeTimeout = common.WakeTimeoutDefault
|
||||
}
|
||||
if cont.StopTimeout == "" {
|
||||
cont.StopTimeout = common.StopTimeoutDefault
|
||||
}
|
||||
if cont.StopMethod == "" {
|
||||
cont.StopMethod = common.StopMethodDefault
|
||||
}
|
||||
}
|
||||
|
||||
e.Port = joinPorts(lp, pp, extra)
|
||||
|
||||
if e.Port == "" || e.Host == "" {
|
||||
e.Port = "0"
|
||||
if lp != "" {
|
||||
e.Port = lp + ":0"
|
||||
} else {
|
||||
e.Port = "0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
98
internal/proxy/entry/reverse_proxy.go
Normal file
98
internal/proxy/entry/reverse_proxy.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package entry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type ReverseProxyEntry struct { // real model after validation
|
||||
Raw *RawEntry `json:"raw"`
|
||||
|
||||
Alias fields.Alias `json:"alias,omitempty"`
|
||||
Scheme fields.Scheme `json:"scheme,omitempty"`
|
||||
URL net.URL `json:"url,omitempty"`
|
||||
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
|
||||
PathPatterns fields.PathPatterns `json:"path_patterns,omitempty"`
|
||||
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
|
||||
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
|
||||
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty"`
|
||||
|
||||
/* Docker only */
|
||||
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) TargetName() string {
|
||||
return string(rp.Alias)
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) TargetURL() net.URL {
|
||||
return rp.URL
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) RawEntry() *RawEntry {
|
||||
return rp.Raw
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) LoadBalanceConfig() *loadbalancer.Config {
|
||||
return rp.LoadBalance
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) HealthCheckConfig() *health.HealthCheckConfig {
|
||||
return rp.HealthCheck
|
||||
}
|
||||
|
||||
func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
|
||||
return rp.Idlewatcher
|
||||
}
|
||||
|
||||
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry {
|
||||
cont := m.Container
|
||||
if cont == nil {
|
||||
cont = docker.DummyContainer
|
||||
}
|
||||
|
||||
lb := m.LoadBalance
|
||||
if lb != nil && lb.Link == "" {
|
||||
lb = nil
|
||||
}
|
||||
|
||||
host, err := fields.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
|
||||
port, err := fields.ValidatePort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
|
||||
b.Add(err)
|
||||
|
||||
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
|
||||
b.Add(err)
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ReverseProxyEntry{
|
||||
Raw: m,
|
||||
Alias: fields.NewAlias(m.Alias),
|
||||
Scheme: s,
|
||||
URL: net.NewURL(url),
|
||||
NoTLSVerify: m.NoTLSVerify,
|
||||
PathPatterns: pathPatterns,
|
||||
HealthCheck: m.HealthCheck,
|
||||
LoadBalance: lb,
|
||||
Middlewares: m.Middlewares,
|
||||
Idlewatcher: idleWatcherCfg,
|
||||
}
|
||||
}
|
89
internal/proxy/entry/stream.go
Normal file
89
internal/proxy/entry/stream.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package entry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type StreamEntry struct {
|
||||
Raw *RawEntry `json:"raw"`
|
||||
|
||||
Alias fields.Alias `json:"alias,omitempty"`
|
||||
Scheme fields.StreamScheme `json:"scheme,omitempty"`
|
||||
URL net.URL `json:"url,omitempty"`
|
||||
Host fields.Host `json:"host,omitempty"`
|
||||
Port fields.StreamPort `json:"port,omitempty"`
|
||||
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
|
||||
|
||||
/* Docker only */
|
||||
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
|
||||
}
|
||||
|
||||
func (s *StreamEntry) TargetName() string {
|
||||
return string(s.Alias)
|
||||
}
|
||||
|
||||
func (s *StreamEntry) TargetURL() net.URL {
|
||||
return s.URL
|
||||
}
|
||||
|
||||
func (s *StreamEntry) RawEntry() *RawEntry {
|
||||
return s.Raw
|
||||
}
|
||||
|
||||
func (s *StreamEntry) LoadBalanceConfig() *loadbalancer.Config {
|
||||
// TODO: support stream load balance
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StreamEntry) HealthCheckConfig() *health.HealthCheckConfig {
|
||||
return s.HealthCheck
|
||||
}
|
||||
|
||||
func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
|
||||
return s.Idlewatcher
|
||||
}
|
||||
|
||||
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry {
|
||||
cont := m.Container
|
||||
if cont == nil {
|
||||
cont = docker.DummyContainer
|
||||
}
|
||||
|
||||
host, err := fields.ValidateHost(m.Host)
|
||||
b.Add(err)
|
||||
|
||||
port, err := fields.ValidateStreamPort(m.Port)
|
||||
b.Add(err)
|
||||
|
||||
scheme, err := fields.ValidateStreamScheme(m.Scheme)
|
||||
b.Add(err)
|
||||
|
||||
url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort)))
|
||||
b.Add(err)
|
||||
|
||||
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
|
||||
b.Add(err)
|
||||
|
||||
if b.HasError() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &StreamEntry{
|
||||
Raw: m,
|
||||
Alias: fields.NewAlias(m.Alias),
|
||||
Scheme: *scheme,
|
||||
URL: url,
|
||||
Host: host,
|
||||
Port: port,
|
||||
HealthCheck: m.HealthCheck,
|
||||
Idlewatcher: idleWatcherCfg,
|
||||
}
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type Signal string
|
||||
|
||||
func ValidateSignal(s string) (Signal, E.NestedError) {
|
||||
switch s {
|
||||
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
|
||||
"INT", "TERM", "HUP", "QUIT":
|
||||
return Signal(s), nil
|
||||
}
|
||||
|
||||
return "", E.Invalid("signal", s)
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type StopMethod string
|
||||
|
||||
const (
|
||||
StopMethodPause StopMethod = "pause"
|
||||
StopMethodStop StopMethod = "stop"
|
||||
StopMethodKill StopMethod = "kill"
|
||||
)
|
||||
|
||||
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
|
||||
sm := StopMethod(s)
|
||||
switch sm {
|
||||
case StopMethodPause, StopMethodStop, StopMethodKill:
|
||||
return sm, nil
|
||||
default:
|
||||
return "", E.Invalid("stop_method", sm)
|
||||
}
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, E.Invalid("duration", value)
|
||||
}
|
||||
if d < 0 {
|
||||
return 0, E.Invalid("duration", "negative value")
|
||||
}
|
||||
return d, nil
|
||||
}
|
|
@ -9,23 +9,21 @@ import (
|
|||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
url "github.com/yusing/go-proxy/internal/net/types"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type (
|
||||
HTTPRoute struct {
|
||||
*P.ReverseProxyEntry
|
||||
*entry.ReverseProxyEntry
|
||||
|
||||
HealthMon health.HealthMonitor `json:"health,omitempty"`
|
||||
|
||||
|
@ -33,6 +31,8 @@ type (
|
|||
server *loadbalancer.Server
|
||||
handler http.Handler
|
||||
rp *gphttp.ReverseProxy
|
||||
|
||||
task task.Task
|
||||
}
|
||||
|
||||
SubdomainKey = PT.Alias
|
||||
|
@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
||||
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
|
||||
var trans *http.Transport
|
||||
|
||||
if entry.NoTLSVerify {
|
||||
|
@ -84,12 +84,10 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
|
|||
}
|
||||
}
|
||||
|
||||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
r := &HTTPRoute{
|
||||
ReverseProxyEntry: entry,
|
||||
rp: rp,
|
||||
task: task.DummyTask(),
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
@ -98,39 +96,34 @@ func (r *HTTPRoute) String() string {
|
|||
return string(r.Alias)
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) URL() url.URL {
|
||||
return r.ReverseProxyEntry.URL
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Start() E.NestedError {
|
||||
if r.ShouldNotServe() {
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
}
|
||||
|
||||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
if r.handler != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) {
|
||||
if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
|
||||
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
|
||||
r.HealthCheck.Disabled = true
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.UseIdleWatcher():
|
||||
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
|
||||
case entry.UseIdleWatcher(r):
|
||||
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
|
||||
waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
waker := idlewatcher.NewWaker(watcher, r.rp)
|
||||
r.handler = waker
|
||||
r.HealthMon = waker
|
||||
case !r.HealthCheck.Disabled:
|
||||
r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask(r.String()), r.URL(), r.HealthCheck)
|
||||
case entry.UseHealthCheck(r):
|
||||
r.HealthMon = health.NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheck, r.rp.Transport)
|
||||
}
|
||||
r.task = providerSubtask
|
||||
|
||||
if r.handler == nil {
|
||||
switch {
|
||||
|
@ -146,44 +139,26 @@ func (r *HTTPRoute) Start() E.NestedError {
|
|||
}
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Start()
|
||||
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
|
||||
logrus.Warn(E.FailWith("health monitor", err))
|
||||
}
|
||||
}
|
||||
|
||||
if r.UseLoadBalance() {
|
||||
if entry.UseLoadBalance(r) {
|
||||
r.addToLoadBalancer()
|
||||
} else {
|
||||
httpRoutes.Store(string(r.Alias), r)
|
||||
r.task.OnComplete("stop rp", func() {
|
||||
httpRoutes.Delete(string(r.Alias))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Stop() (_ E.NestedError) {
|
||||
if r.handler == nil {
|
||||
return
|
||||
}
|
||||
|
||||
httpRoutesMu.Lock()
|
||||
defer httpRoutesMu.Unlock()
|
||||
|
||||
if r.loadBalancer != nil {
|
||||
r.removeFromLoadBalancer()
|
||||
} else {
|
||||
httpRoutes.Delete(string(r.Alias))
|
||||
}
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Stop()
|
||||
r.HealthMon = nil
|
||||
}
|
||||
|
||||
r.handler = nil
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) Started() bool {
|
||||
return r.handler != nil
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (r *HTTPRoute) Finish(reason string) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) addToLoadBalancer() {
|
||||
|
@ -197,10 +172,14 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
|||
}
|
||||
} else {
|
||||
lb = loadbalancer.New(r.LoadBalance)
|
||||
lb.Start()
|
||||
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
|
||||
lbTask.OnComplete("remove lb from routes", func() {
|
||||
httpRoutes.Delete(r.LoadBalance.Link)
|
||||
})
|
||||
lb.Start(lbTask)
|
||||
linked = &HTTPRoute{
|
||||
ReverseProxyEntry: &P.ReverseProxyEntry{
|
||||
Raw: &types.RawEntry{
|
||||
ReverseProxyEntry: &entry.ReverseProxyEntry{
|
||||
Raw: &entry.RawEntry{
|
||||
Homepage: r.Raw.Homepage,
|
||||
},
|
||||
Alias: PT.Alias(lb.Link),
|
||||
|
@ -214,16 +193,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
|||
r.loadBalancer = lb
|
||||
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
|
||||
lb.AddServer(r.server)
|
||||
}
|
||||
|
||||
func (r *HTTPRoute) removeFromLoadBalancer() {
|
||||
r.loadBalancer.RemoveServer(r.server)
|
||||
if r.loadBalancer.IsEmpty() {
|
||||
httpRoutes.Delete(r.LoadBalance.Link)
|
||||
logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link)
|
||||
}
|
||||
r.server = nil
|
||||
r.loadBalancer = nil
|
||||
r.task.OnComplete("remove server from lb", func() {
|
||||
lb.RemoveServer(r.server)
|
||||
})
|
||||
}
|
||||
|
||||
func ProxyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -10,10 +10,9 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
type DockerProvider struct {
|
||||
|
@ -43,7 +42,7 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
|
|||
|
||||
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
||||
routes = R.NewRoutes()
|
||||
entries := types.NewProxyEntries()
|
||||
entries := entry.NewProxyEntries()
|
||||
|
||||
info, err := D.GetClientInfo(p.dockerHost, true)
|
||||
if err != nil {
|
||||
|
@ -66,12 +65,12 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
|||
// there may be some valid entries in `en`
|
||||
dups := entries.MergeFrom(newEntries)
|
||||
// add the duplicate proxy entries to the error
|
||||
dups.RangeAll(func(k string, v *types.RawEntry) {
|
||||
dups.RangeAll(func(k string, v *entry.RawEntry) {
|
||||
errors.Addf("duplicate alias %s", k)
|
||||
})
|
||||
}
|
||||
|
||||
entries.RangeAll(func(_ string, e *types.RawEntry) {
|
||||
entries.RangeAll(func(_ string, e *entry.RawEntry) {
|
||||
e.Container.DockerHost = p.dockerHost
|
||||
})
|
||||
|
||||
|
@ -88,85 +87,10 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
|
|||
strings.HasSuffix(container.ContainerName, "-old")
|
||||
}
|
||||
|
||||
func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) {
|
||||
b := E.NewBuilder("event %s error", event)
|
||||
defer b.To(&res.err)
|
||||
|
||||
matches := R.NewRoutes()
|
||||
oldRoutes.RangeAllParallel(func(k string, v *R.Route) {
|
||||
if v.Entry.Container.ContainerID == event.ActorID ||
|
||||
v.Entry.Container.ContainerName == event.ActorName {
|
||||
matches.Store(k, v)
|
||||
}
|
||||
})
|
||||
//FIXME: docker event die stuck
|
||||
|
||||
var newRoutes R.Routes
|
||||
var err E.NestedError
|
||||
|
||||
switch {
|
||||
// id & container name changed
|
||||
case matches.Size() == 0:
|
||||
matches = oldRoutes
|
||||
newRoutes, err = p.LoadRoutesImpl()
|
||||
b.Add(err)
|
||||
case event.Action == events.ActionContainerDestroy:
|
||||
// stop all old routes
|
||||
matches.RangeAllParallel(func(_ string, v *R.Route) {
|
||||
oldRoutes.Delete(v.Entry.Alias)
|
||||
b.Add(v.Stop())
|
||||
res.nRemoved++
|
||||
})
|
||||
return
|
||||
default:
|
||||
cont, err := D.Inspect(p.dockerHost, event.ActorID)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("inspect container", err))
|
||||
return
|
||||
}
|
||||
|
||||
if p.shouldIgnore(cont) {
|
||||
// stop all old routes
|
||||
matches.RangeAllParallel(func(_ string, v *R.Route) {
|
||||
b.Add(v.Stop())
|
||||
res.nRemoved++
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := p.entriesFromContainerLabels(cont)
|
||||
b.Add(err)
|
||||
newRoutes, err = R.FromEntries(entries)
|
||||
b.Add(err)
|
||||
}
|
||||
|
||||
matches.RangeAll(func(k string, v *R.Route) {
|
||||
if !newRoutes.Has(k) && !oldRoutes.Has(k) {
|
||||
b.Add(v.Stop())
|
||||
matches.Delete(k)
|
||||
res.nRemoved++
|
||||
}
|
||||
})
|
||||
|
||||
newRoutes.RangeAll(func(alias string, newRoute *R.Route) {
|
||||
oldRoute, exists := oldRoutes.Load(alias)
|
||||
if exists {
|
||||
b.Add(oldRoute.Stop())
|
||||
res.nReloaded++
|
||||
} else {
|
||||
res.nAdded++
|
||||
}
|
||||
b.Add(newRoute.Start())
|
||||
oldRoutes.Store(alias, newRoute)
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Returns a list of proxy entries for a container.
|
||||
// Always non-nil.
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries types.RawEntries, _ E.NestedError) {
|
||||
entries = types.NewProxyEntries()
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
|
||||
entries = entry.NewProxyEntries()
|
||||
|
||||
if p.shouldIgnore(container) {
|
||||
return
|
||||
|
@ -174,7 +98,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
|
|||
|
||||
// init entries map for all aliases
|
||||
for _, a := range container.Aliases {
|
||||
entries.Store(a, &types.RawEntry{
|
||||
entries.Store(a, &entry.RawEntry{
|
||||
Alias: a,
|
||||
Container: container,
|
||||
})
|
||||
|
@ -186,14 +110,14 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
|
|||
}
|
||||
|
||||
// remove all entries that failed to fill in missing fields
|
||||
entries.RangeAll(func(_ string, re *types.RawEntry) {
|
||||
entries.RangeAll(func(_ string, re *entry.RawEntry) {
|
||||
re.FillMissingFields()
|
||||
})
|
||||
|
||||
return entries, errors.Build().Subject(container.ContainerName)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEntries, key, val string) (res E.NestedError) {
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
|
||||
b := E.NewBuilder("errors in label %s", key)
|
||||
defer b.To(&res)
|
||||
|
||||
|
@ -220,7 +144,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
|
|||
}
|
||||
if lbl.Target == D.WildcardAlias {
|
||||
// apply label for all aliases
|
||||
entries.RangeAll(func(a string, e *types.RawEntry) {
|
||||
entries.RangeAll(func(a string, e *entry.RawEntry) {
|
||||
if err = D.ApplyLabel(e, lbl); err != nil {
|
||||
b.Add(err)
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
@ -46,7 +46,7 @@ func TestApplyLabelWildcard(t *testing.T) {
|
|||
Names: dummyNames,
|
||||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b",
|
||||
D.LabelIdleTimeout: common.IdleTimeoutDefault,
|
||||
D.LabelIdleTimeout: "",
|
||||
D.LabelStopMethod: common.StopMethodDefault,
|
||||
D.LabelStopSignal: "SIGTERM",
|
||||
D.LabelStopTimeout: common.StopTimeoutDefault,
|
||||
|
@ -62,7 +62,7 @@ func TestApplyLabelWildcard(t *testing.T) {
|
|||
"proxy.a.middlewares.middleware2.prop3": "value3",
|
||||
"proxy.a.middlewares.middleware2.prop4": "value4",
|
||||
},
|
||||
}, ""))
|
||||
}, client.DefaultDockerHost))
|
||||
ExpectNoError(t, err.Error())
|
||||
|
||||
a, ok := entries.Load("a")
|
||||
|
@ -88,8 +88,8 @@ func TestApplyLabelWildcard(t *testing.T) {
|
|||
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
|
||||
ExpectEqual(t, len(b.Middlewares), 0)
|
||||
|
||||
ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault)
|
||||
ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault)
|
||||
ExpectEqual(t, a.Container.IdleTimeout, "")
|
||||
ExpectEqual(t, b.Container.IdleTimeout, "")
|
||||
|
||||
ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault)
|
||||
ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault)
|
||||
|
@ -107,6 +107,7 @@ func TestApplyLabelWildcard(t *testing.T) {
|
|||
func TestApplyLabelWithAlias(t *testing.T) {
|
||||
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
|
||||
Names: dummyNames,
|
||||
State: "running",
|
||||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b,c",
|
||||
"proxy.a.no_tls_verify": "true",
|
||||
|
@ -114,7 +115,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
|
|||
"proxy.b.port": "1234",
|
||||
"proxy.c.scheme": "https",
|
||||
},
|
||||
}, ""))
|
||||
}, client.DefaultDockerHost))
|
||||
a, ok := entries.Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
b, ok := entries.Load("b")
|
||||
|
@ -134,6 +135,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
|
|||
func TestApplyLabelWithRef(t *testing.T) {
|
||||
entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{
|
||||
Names: dummyNames,
|
||||
State: "running",
|
||||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b,c",
|
||||
"proxy.#1.host": "localhost",
|
||||
|
@ -142,7 +144,7 @@ func TestApplyLabelWithRef(t *testing.T) {
|
|||
"proxy.#3.port": "1111",
|
||||
"proxy.#3.scheme": "https",
|
||||
},
|
||||
}, "")))
|
||||
}, client.DefaultDockerHost)))
|
||||
a, ok := entries.Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
b, ok := entries.Load("b")
|
||||
|
@ -161,6 +163,7 @@ func TestApplyLabelWithRef(t *testing.T) {
|
|||
func TestApplyLabelWithRefIndexError(t *testing.T) {
|
||||
c := D.FromDocker(&types.Container{
|
||||
Names: dummyNames,
|
||||
State: "running",
|
||||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b",
|
||||
"proxy.#1.host": "localhost",
|
||||
|
@ -173,6 +176,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
|
|||
|
||||
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
|
||||
Names: dummyNames,
|
||||
State: "running",
|
||||
Labels: map[string]string{
|
||||
D.LabelAliases: "a,b",
|
||||
"proxy.#0.host": "localhost",
|
||||
|
@ -183,7 +187,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPublicIPLocalhost(t *testing.T) {
|
||||
c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost)
|
||||
c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost)
|
||||
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1")
|
||||
|
@ -191,7 +195,7 @@ func TestPublicIPLocalhost(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPublicIPRemote(t *testing.T) {
|
||||
c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375")
|
||||
c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375")
|
||||
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
|
||||
|
@ -218,6 +222,7 @@ func TestPrivateIPLocalhost(t *testing.T) {
|
|||
func TestPrivateIPRemote(t *testing.T) {
|
||||
c := D.FromDocker(&types.Container{
|
||||
Names: dummyNames,
|
||||
State: "running",
|
||||
NetworkSettings: &types.SummaryNetworkSettings{
|
||||
Networks: map[string]*network.EndpointSettings{
|
||||
"network": {
|
||||
|
@ -239,6 +244,7 @@ func TestStreamDefaultValues(t *testing.T) {
|
|||
privIP := "172.17.0.123"
|
||||
cont := &types.Container{
|
||||
Names: []string{"a"},
|
||||
State: "running",
|
||||
NetworkSettings: &types.SummaryNetworkSettings{
|
||||
Networks: map[string]*network.EndpointSettings{
|
||||
"network": {
|
||||
|
@ -256,9 +262,8 @@ func TestStreamDefaultValues(t *testing.T) {
|
|||
|
||||
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
entry := Must(P.ValidateEntry(raw))
|
||||
|
||||
a := ExpectType[*P.StreamEntry](t, entry)
|
||||
en := Must(entry.ValidateEntry(raw))
|
||||
a := ExpectType[*entry.StreamEntry](t, en)
|
||||
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
|
||||
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
|
||||
ExpectEqual(t, a.Host, T.Host(privIP))
|
||||
|
@ -270,9 +275,8 @@ func TestStreamDefaultValues(t *testing.T) {
|
|||
c := D.FromDocker(cont, "tcp://1.2.3.4:2375")
|
||||
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
|
||||
ExpectTrue(t, ok)
|
||||
entry := Must(P.ValidateEntry(raw))
|
||||
|
||||
a := ExpectType[*P.StreamEntry](t, entry)
|
||||
en := Must(entry.ValidateEntry(raw))
|
||||
a := ExpectType[*entry.StreamEntry](t, en)
|
||||
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
|
||||
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
|
||||
ExpectEqual(t, a.Host, "1.2.3.4")
|
109
internal/route/provider/event_handler.go
Normal file
109
internal/route/provider/event_handler.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
)
|
||||
|
||||
type EventHandler struct {
|
||||
provider *Provider
|
||||
|
||||
added []string
|
||||
removed []string
|
||||
paused []string
|
||||
updated []string
|
||||
errs E.Builder
|
||||
}
|
||||
|
||||
func (provider *Provider) newEventHandler() *EventHandler {
|
||||
return &EventHandler{
|
||||
provider: provider,
|
||||
errs: E.NewBuilder("event errors"),
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
|
||||
oldRoutes := handler.provider.routes
|
||||
newRoutes, err := handler.provider.LoadRoutesImpl()
|
||||
if err != nil {
|
||||
handler.errs.Add(err.Subject("load routes"))
|
||||
return
|
||||
}
|
||||
|
||||
oldRoutes.RangeAll(func(k string, v *route.Route) {
|
||||
if !newRoutes.Has(k) {
|
||||
handler.Remove(v)
|
||||
}
|
||||
})
|
||||
newRoutes.RangeAll(func(k string, newr *route.Route) {
|
||||
if oldRoutes.Has(k) {
|
||||
for _, ev := range events {
|
||||
if handler.match(ev, newr) {
|
||||
old, ok := oldRoutes.Load(k)
|
||||
if !ok { // should not happen
|
||||
panic("race condition")
|
||||
}
|
||||
handler.Update(parent, old, newr)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handler.Add(parent, newr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
|
||||
switch handler.provider.t {
|
||||
case ProviderTypeDocker:
|
||||
return route.Entry.Container.ContainerID == event.ActorID ||
|
||||
route.Entry.Container.ContainerName == event.ActorName
|
||||
case ProviderTypeFile:
|
||||
return true
|
||||
}
|
||||
// should never happen
|
||||
return false
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
|
||||
err := handler.provider.startRoute(parent, route)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
} else {
|
||||
handler.added = append(handler.added, route.Entry.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Remove(route *route.Route) {
|
||||
route.Finish("route removal")
|
||||
handler.removed = append(handler.removed, route.Entry.Alias)
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) {
|
||||
oldRoute.Finish("route update")
|
||||
err := handler.provider.startRoute(parent, newRoute)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
} else {
|
||||
handler.updated = append(handler.updated, newRoute.Entry.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Log() {
|
||||
results := E.NewBuilder("event occured")
|
||||
for _, alias := range handler.added {
|
||||
results.Addf("added %s", alias)
|
||||
}
|
||||
for _, alias := range handler.removed {
|
||||
results.Addf("removed %s", alias)
|
||||
}
|
||||
for _, alias := range handler.updated {
|
||||
results.Addf("updated %s", alias)
|
||||
}
|
||||
results.Add(handler.errs.Build())
|
||||
if result := results.Build(); result != nil {
|
||||
handler.provider.l.Info(result)
|
||||
}
|
||||
}
|
|
@ -7,8 +7,8 @@ import (
|
|||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
)
|
||||
|
@ -42,38 +42,13 @@ func (p FileProvider) String() string {
|
|||
return p.fileName
|
||||
}
|
||||
|
||||
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
|
||||
b := E.NewBuilder("event %s error", event)
|
||||
defer b.To(&res.err)
|
||||
|
||||
newRoutes, err := p.LoadRoutesImpl()
|
||||
if err != nil {
|
||||
b.Add(err)
|
||||
return
|
||||
}
|
||||
|
||||
res.nRemoved = newRoutes.Size()
|
||||
routes.RangeAllParallel(func(_ string, v *R.Route) {
|
||||
b.Add(v.Stop())
|
||||
})
|
||||
routes.Clear()
|
||||
|
||||
newRoutes.RangeAllParallel(func(_ string, v *R.Route) {
|
||||
b.Add(v.Start())
|
||||
})
|
||||
res.nAdded = newRoutes.Size()
|
||||
|
||||
routes.MergeFrom(newRoutes)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
|
||||
routes = R.NewRoutes()
|
||||
|
||||
b := E.NewBuilder("file %q validation failure", p.fileName)
|
||||
defer b.To(&res)
|
||||
|
||||
entries := types.NewProxyEntries()
|
||||
entries := entry.NewProxyEntries()
|
||||
|
||||
data, err := E.Check(os.ReadFile(p.path))
|
||||
if err != nil {
|
|
@ -1,14 +1,16 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -19,18 +21,14 @@ type (
|
|||
t ProviderType
|
||||
routes R.Routes
|
||||
|
||||
watcher W.Watcher
|
||||
watcherTask common.Task
|
||||
watcherCancel context.CancelFunc
|
||||
watcher W.Watcher
|
||||
|
||||
l *logrus.Entry
|
||||
}
|
||||
ProviderImpl interface {
|
||||
fmt.Stringer
|
||||
NewWatcher() W.Watcher
|
||||
// even returns error, routes must be non-nil
|
||||
LoadRoutesImpl() (R.Routes, E.NestedError)
|
||||
OnEvent(event W.Event, routes R.Routes) EventResult
|
||||
String() string
|
||||
}
|
||||
ProviderType string
|
||||
ProviderStats struct {
|
||||
|
@ -38,17 +36,13 @@ type (
|
|||
NumStreams int `json:"num_streams"`
|
||||
Type ProviderType `json:"type"`
|
||||
}
|
||||
EventResult struct {
|
||||
nAdded int
|
||||
nRemoved int
|
||||
nReloaded int
|
||||
err E.NestedError
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderTypeDocker ProviderType = "docker"
|
||||
ProviderTypeFile ProviderType = "file"
|
||||
|
||||
providerEventFlushInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProvider(name string, t ProviderType) *Provider {
|
||||
|
@ -106,32 +100,48 @@ func (p *Provider) MarshalText() ([]byte, error) {
|
|||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
func (p *Provider) StartAllRoutes() (res E.NestedError) {
|
||||
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
|
||||
subtask := parent.Subtask("route %s", r.Entry.Alias)
|
||||
err := r.Start(subtask)
|
||||
if err != nil {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
subtask.Finish(err.String()) // just to ensure
|
||||
return err
|
||||
} else {
|
||||
subtask.OnComplete("del from provider", func() {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
|
||||
errors := E.NewBuilder("errors starting routes")
|
||||
defer errors.To(&res)
|
||||
|
||||
// start watcher no matter load success or not
|
||||
go p.watchEvents()
|
||||
// routes and event queue will stop on parent cancel
|
||||
providerTask := configSubtask
|
||||
|
||||
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
|
||||
errors.Add(r.Start().Subject(r))
|
||||
errors.Add(p.startRoute(providerTask, r))
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Provider) StopAllRoutes() (res E.NestedError) {
|
||||
if p.watcherCancel != nil {
|
||||
p.watcherCancel()
|
||||
p.watcherCancel = nil
|
||||
}
|
||||
|
||||
errors := E.NewBuilder("errors stopping routes")
|
||||
defer errors.To(&res)
|
||||
|
||||
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
|
||||
errors.Add(r.Stop().Subject(r))
|
||||
})
|
||||
p.routes.Clear()
|
||||
eventQueue := events.NewEventQueue(
|
||||
providerTask,
|
||||
providerEventFlushInterval,
|
||||
func(flushTask task.Task, events []events.Event) {
|
||||
handler := p.newEventHandler()
|
||||
// routes' lifetime should follow the provider's lifetime
|
||||
handler.Handle(providerTask, events)
|
||||
handler.Log()
|
||||
flushTask.Finish("events flushed")
|
||||
},
|
||||
func(err E.NestedError) {
|
||||
p.l.Error(err)
|
||||
},
|
||||
)
|
||||
eventQueue.Start(p.watcher.Events(providerTask.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -147,7 +157,6 @@ func (p *Provider) LoadRoutes() E.NestedError {
|
|||
var err E.NestedError
|
||||
p.routes, err = p.LoadRoutesImpl()
|
||||
if p.routes.Size() > 0 {
|
||||
p.l.Infof("loaded %d routes", p.routes.Size())
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
|
@ -156,13 +165,14 @@ func (p *Provider) LoadRoutes() E.NestedError {
|
|||
return E.FailWith("loading routes", err)
|
||||
}
|
||||
|
||||
func (p *Provider) NumRoutes() int {
|
||||
return p.routes.Size()
|
||||
}
|
||||
|
||||
func (p *Provider) Statistics() ProviderStats {
|
||||
numRPs := 0
|
||||
numStreams := 0
|
||||
p.routes.RangeAll(func(_ string, r *R.Route) {
|
||||
if !r.Started() {
|
||||
return
|
||||
}
|
||||
switch r.Type {
|
||||
case R.RouteTypeReverseProxy:
|
||||
numRPs++
|
||||
|
@ -176,34 +186,3 @@ func (p *Provider) Statistics() ProviderStats {
|
|||
Type: p.t,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) watchEvents() {
|
||||
p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name)
|
||||
defer p.watcherTask.Finished()
|
||||
|
||||
events, errs := p.watcher.Events(p.watcherTask.Context())
|
||||
l := p.l.WithField("module", "watcher")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.watcherTask.Context().Done():
|
||||
return
|
||||
case event := <-events:
|
||||
task := p.watcherTask.Subtask("%s event %s", event.Type, event)
|
||||
l.Infof("%s event %q", event.Type, event)
|
||||
res := p.OnEvent(event, p.routes)
|
||||
task.Finished()
|
||||
if res.nAdded+res.nRemoved+res.nReloaded > 0 {
|
||||
l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded)
|
||||
}
|
||||
if res.err != nil {
|
||||
l.Error(res.err)
|
||||
}
|
||||
case err := <-errs:
|
||||
if err == nil || err.Is(context.Canceled) {
|
||||
continue
|
||||
}
|
||||
l.Errorf("watcher error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,8 +4,8 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
url "github.com/yusing/go-proxy/internal/net/types"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
@ -16,16 +16,16 @@ type (
|
|||
_ U.NoCopy
|
||||
impl
|
||||
Type RouteType
|
||||
Entry *types.RawEntry
|
||||
Entry *entry.RawEntry
|
||||
}
|
||||
Routes = F.Map[string, *Route]
|
||||
|
||||
impl interface {
|
||||
Start() E.NestedError
|
||||
Stop() E.NestedError
|
||||
Started() bool
|
||||
entry.Entry
|
||||
task.TaskStarter
|
||||
task.TaskFinisher
|
||||
String() string
|
||||
URL() url.URL
|
||||
TargetURL() url.URL
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -44,8 +44,8 @@ func (rt *Route) Container() *docker.Container {
|
|||
return rt.Entry.Container
|
||||
}
|
||||
|
||||
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
|
||||
entry, err := P.ValidateEntry(en)
|
||||
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
|
||||
en, err := entry.ValidateEntry(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -53,11 +53,11 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
|
|||
var t RouteType
|
||||
var rt impl
|
||||
|
||||
switch e := entry.(type) {
|
||||
case *P.StreamEntry:
|
||||
switch e := en.(type) {
|
||||
case *entry.StreamEntry:
|
||||
t = RouteTypeStream
|
||||
rt, err = NewStreamRoute(e)
|
||||
case *P.ReverseProxyEntry:
|
||||
case *entry.ReverseProxyEntry:
|
||||
t = RouteTypeReverseProxy
|
||||
rt, err = NewHTTPRoute(e)
|
||||
default:
|
||||
|
@ -69,19 +69,21 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
|
|||
return &Route{
|
||||
impl: rt,
|
||||
Type: t,
|
||||
Entry: en,
|
||||
Entry: raw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func FromEntries(entries types.RawEntries) (Routes, E.NestedError) {
|
||||
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
|
||||
b := E.NewBuilder("errors in routes")
|
||||
|
||||
routes := NewRoutes()
|
||||
entries.RangeAll(func(alias string, entry *types.RawEntry) {
|
||||
entry.Alias = alias
|
||||
r, err := NewRoute(entry)
|
||||
entries.RangeAllParallel(func(alias string, en *entry.RawEntry) {
|
||||
en.Alias = alias
|
||||
r, err := NewRoute(en)
|
||||
if err != nil {
|
||||
b.Add(err.Subject(alias))
|
||||
} else if entry.ShouldNotServe(r) {
|
||||
return
|
||||
} else {
|
||||
routes.Store(alias, r)
|
||||
}
|
||||
|
|
|
@ -4,169 +4,141 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
stdNet "net"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
url "github.com/yusing/go-proxy/internal/net/types"
|
||||
P "github.com/yusing/go-proxy/internal/proxy"
|
||||
PT "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type StreamRoute struct {
|
||||
*P.StreamEntry
|
||||
StreamImpl `json:"-"`
|
||||
*entry.StreamEntry
|
||||
net.Stream `json:"-"`
|
||||
|
||||
HealthMon health.HealthMonitor `json:"health"`
|
||||
|
||||
url url.URL
|
||||
|
||||
task common.Task
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
task task.Task
|
||||
|
||||
l logrus.FieldLogger
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type StreamImpl interface {
|
||||
Setup() error
|
||||
Accept() (any, error)
|
||||
Handle(conn any) error
|
||||
CloseListeners()
|
||||
String() string
|
||||
}
|
||||
|
||||
var streamRoutes = F.NewMapOf[string, *StreamRoute]()
|
||||
var (
|
||||
streamRoutes = F.NewMapOf[string, *StreamRoute]()
|
||||
streamRoutesMu sync.Mutex
|
||||
)
|
||||
|
||||
func GetStreamProxies() F.Map[string, *StreamRoute] {
|
||||
return streamRoutes
|
||||
}
|
||||
|
||||
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
|
||||
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
|
||||
// TODO: support non-coherent scheme
|
||||
if !entry.Scheme.IsCoherent() {
|
||||
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
|
||||
}
|
||||
url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort))
|
||||
if err != nil {
|
||||
// !! should not happen
|
||||
panic(err)
|
||||
}
|
||||
base := &StreamRoute{
|
||||
return &StreamRoute{
|
||||
StreamEntry: entry,
|
||||
url: url,
|
||||
}
|
||||
if entry.Scheme.ListeningScheme.IsTCP() {
|
||||
base.StreamImpl = NewTCPRoute(base)
|
||||
} else {
|
||||
base.StreamImpl = NewUDPRoute(base)
|
||||
}
|
||||
base.l = logrus.WithField("route", base.StreamImpl)
|
||||
return base, nil
|
||||
task: task.DummyTask(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Finish(reason string) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) String() string {
|
||||
return fmt.Sprintf("stream %s", r.Alias)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) URL() url.URL {
|
||||
return r.url
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Start() E.NestedError {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.Port.ProxyPort == PT.NoPort || r.task != nil {
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
}
|
||||
r.task, r.cancel = common.NewTaskWithCancel(r.String())
|
||||
|
||||
streamRoutesMu.Lock()
|
||||
defer streamRoutesMu.Unlock()
|
||||
|
||||
if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
|
||||
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
|
||||
r.HealthCheck.Disabled = true
|
||||
}
|
||||
|
||||
if r.Scheme.ListeningScheme.IsTCP() {
|
||||
r.Stream = NewTCPRoute(r)
|
||||
} else {
|
||||
r.Stream = NewUDPRoute(r)
|
||||
}
|
||||
r.l = logrus.WithField("route", r.Stream.String())
|
||||
|
||||
switch {
|
||||
case entry.UseIdleWatcher(r):
|
||||
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
|
||||
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Stream = waker
|
||||
r.HealthMon = waker
|
||||
case entry.UseHealthCheck(r):
|
||||
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
|
||||
}
|
||||
r.task = providerSubtask
|
||||
r.task.OnComplete("stop stream", r.CloseListeners)
|
||||
|
||||
if err := r.Setup(); err != nil {
|
||||
return E.FailWith("setup", err)
|
||||
}
|
||||
r.done = make(chan struct{})
|
||||
r.l.Infof("listening on port %d", r.Port.ListeningPort)
|
||||
|
||||
go r.acceptConnections()
|
||||
if !r.Healthcheck.Disabled {
|
||||
r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck)
|
||||
r.HealthMon.Start()
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Start(r.task.Subtask("health monitor"))
|
||||
}
|
||||
streamRoutes.Store(string(r.Alias), r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Stop() E.NestedError {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.task == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
streamRoutes.Delete(string(r.Alias))
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Stop()
|
||||
r.HealthMon = nil
|
||||
}
|
||||
|
||||
r.cancel()
|
||||
r.CloseListeners()
|
||||
|
||||
<-r.done
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Started() bool {
|
||||
return r.task != nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) acceptConnections() {
|
||||
var connWg sync.WaitGroup
|
||||
|
||||
task := r.task.Subtask("%s accept connections", r.String())
|
||||
|
||||
defer func() {
|
||||
connWg.Wait()
|
||||
task.Finished()
|
||||
r.task.Finished()
|
||||
r.task, r.cancel = nil, nil
|
||||
close(r.done)
|
||||
r.done = nil
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
case <-r.task.Context().Done():
|
||||
return
|
||||
default:
|
||||
conn, err := r.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
case <-r.task.Context().Done():
|
||||
return
|
||||
default:
|
||||
var nErr *net.OpError
|
||||
var nErr *stdNet.OpError
|
||||
ok := errors.As(err, &nErr)
|
||||
if !(ok && nErr.Timeout()) {
|
||||
r.l.Error(err)
|
||||
r.l.Error("accept connection error: ", err)
|
||||
r.task.Finish(err.Error())
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
connWg.Add(1)
|
||||
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
go func() {
|
||||
err := r.Handle(conn)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
r.l.Error(err)
|
||||
connTask.Finish(err.Error())
|
||||
} else {
|
||||
connTask.Finish("connection closed")
|
||||
}
|
||||
connWg.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
|
@ -21,7 +22,7 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
func NewTCPRoute(base *StreamRoute) StreamImpl {
|
||||
func NewTCPRoute(base *StreamRoute) *TCPRoute {
|
||||
return &TCPRoute{StreamRoute: base}
|
||||
}
|
||||
|
||||
|
@ -36,19 +37,16 @@ func (route *TCPRoute) Setup() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (route *TCPRoute) Accept() (any, error) {
|
||||
func (route *TCPRoute) Accept() (types.StreamConn, error) {
|
||||
route.listener.SetDeadline(time.Now().Add(time.Second))
|
||||
return route.listener.Accept()
|
||||
}
|
||||
|
||||
func (route *TCPRoute) Handle(c any) error {
|
||||
func (route *TCPRoute) Handle(c types.StreamConn) error {
|
||||
clientConn := c.(net.Conn)
|
||||
|
||||
defer clientConn.Close()
|
||||
go func() {
|
||||
<-route.task.Context().Done()
|
||||
clientConn.Close()
|
||||
}()
|
||||
route.task.OnComplete("close conn", func() { clientConn.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
|
||||
|
||||
|
@ -70,5 +68,4 @@ func (route *TCPRoute) CloseListeners() {
|
|||
return
|
||||
}
|
||||
route.listener.Close()
|
||||
route.listener = nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
|
@ -33,7 +35,7 @@ var NewUDPConnMap = F.NewMap[UDPConnMap]
|
|||
|
||||
const udpBufferSize = 8192
|
||||
|
||||
func NewUDPRoute(base *StreamRoute) StreamImpl {
|
||||
func NewUDPRoute(base *StreamRoute) *UDPRoute {
|
||||
return &UDPRoute{
|
||||
StreamRoute: base,
|
||||
connMap: NewUDPConnMap(),
|
||||
|
@ -64,7 +66,7 @@ func (route *UDPRoute) Setup() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (route *UDPRoute) Accept() (any, error) {
|
||||
func (route *UDPRoute) Accept() (types.StreamConn, error) {
|
||||
in := route.listeningConn
|
||||
|
||||
buffer := make([]byte, udpBufferSize)
|
||||
|
@ -104,7 +106,7 @@ func (route *UDPRoute) Accept() (any, error) {
|
|||
return conn, err
|
||||
}
|
||||
|
||||
func (route *UDPRoute) Handle(c any) error {
|
||||
func (route *UDPRoute) Handle(c types.StreamConn) error {
|
||||
conn := c.(*UDPConn)
|
||||
err := conn.Start()
|
||||
route.connMap.Delete(conn.key)
|
||||
|
@ -114,19 +116,25 @@ func (route *UDPRoute) Handle(c any) error {
|
|||
func (route *UDPRoute) CloseListeners() {
|
||||
if route.listeningConn != nil {
|
||||
route.listeningConn.Close()
|
||||
route.listeningConn = nil
|
||||
}
|
||||
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
|
||||
if err := conn.src.Close(); err != nil {
|
||||
route.l.Errorf("error closing src conn: %s", err)
|
||||
}
|
||||
if err := conn.dst.Close(); err != nil {
|
||||
route.l.Error("error closing dst conn: %s", err)
|
||||
if err := conn.Close(); err != nil {
|
||||
route.l.Errorf("error closing conn: %s", err)
|
||||
}
|
||||
})
|
||||
route.connMap.Clear()
|
||||
}
|
||||
|
||||
// Close implements types.StreamConn
|
||||
func (conn *UDPConn) Close() error {
|
||||
return errors.Join(conn.src.Close(), conn.dst.Close())
|
||||
}
|
||||
|
||||
// RemoteAddr implements types.StreamConn
|
||||
func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
return conn.src.RemoteAddr()
|
||||
}
|
||||
|
||||
type sourceRWCloser struct {
|
||||
server *net.UDPConn
|
||||
*net.UDPConn
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
|
@ -9,8 +10,7 @@ import (
|
|||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"golang.org/x/net/context"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
|
@ -21,7 +21,8 @@ type Server struct {
|
|||
httpStarted bool
|
||||
httpsStarted bool
|
||||
startTime time.Time
|
||||
task common.Task
|
||||
|
||||
task task.Task
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
|
@ -84,7 +85,7 @@ func NewServer(opt Options) (s *Server) {
|
|||
CertProvider: opt.CertProvider,
|
||||
http: httpSer,
|
||||
https: httpsSer,
|
||||
task: common.GlobalTask(opt.Name + " server"),
|
||||
task: task.GlobalTask(opt.Name + " server"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,11 +116,7 @@ func (s *Server) Start() {
|
|||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-s.task.Context().Done()
|
||||
s.stop()
|
||||
s.task.Finished()
|
||||
}()
|
||||
s.task.OnComplete("stop server", s.stop)
|
||||
}
|
||||
|
||||
func (s *Server) stop() {
|
||||
|
@ -127,16 +124,13 @@ func (s *Server) stop() {
|
|||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if s.http != nil && s.httpStarted {
|
||||
s.handleErr("http", s.http.Shutdown(ctx))
|
||||
s.handleErr("http", s.http.Shutdown(s.task.Context()))
|
||||
s.httpStarted = false
|
||||
}
|
||||
|
||||
if s.https != nil && s.httpsStarted {
|
||||
s.handleErr("https", s.https.Shutdown(ctx))
|
||||
s.handleErr("https", s.https.Shutdown(s.task.Context()))
|
||||
s.httpsStarted = false
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +141,7 @@ func (s *Server) Uptime() time.Duration {
|
|||
|
||||
func (s *Server) handleErr(scheme string, err error) {
|
||||
switch {
|
||||
case err == nil, errors.Is(err, http.ErrServerClosed):
|
||||
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
|
||||
return
|
||||
default:
|
||||
logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err)
|
||||
|
|
43
internal/task/dummy_task.go
Normal file
43
internal/task/dummy_task.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package task
|
||||
|
||||
import "context"
|
||||
|
||||
type dummyTask struct{}
|
||||
|
||||
func DummyTask() (_ Task) {
|
||||
return
|
||||
}
|
||||
|
||||
// Context implements Task.
|
||||
func (d dummyTask) Context() context.Context {
|
||||
panic("call of dummyTask.Context")
|
||||
}
|
||||
|
||||
// Finish implements Task.
|
||||
func (d dummyTask) Finish() {}
|
||||
|
||||
// Name implements Task.
|
||||
func (d dummyTask) Name() string {
|
||||
return "Dummy Task"
|
||||
}
|
||||
|
||||
// OnComplete implements Task.
|
||||
func (d dummyTask) OnComplete(about string, fn func()) {
|
||||
panic("call of dummyTask.OnComplete")
|
||||
}
|
||||
|
||||
// Parent implements Task.
|
||||
func (d dummyTask) Parent() Task {
|
||||
panic("call of dummyTask.Parent")
|
||||
}
|
||||
|
||||
// Subtask implements Task.
|
||||
func (d dummyTask) Subtask(usageFmt string, args ...any) Task {
|
||||
panic("call of dummyTask.Subtask")
|
||||
}
|
||||
|
||||
// Wait implements Task.
|
||||
func (d dummyTask) Wait() {}
|
||||
|
||||
// WaitSubTasks implements Task.
|
||||
func (d dummyTask) WaitSubTasks() {}
|
310
internal/task/task.go
Normal file
310
internal/task/task.go
Normal file
|
@ -0,0 +1,310 @@
|
|||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
var globalTask = createGlobalTask()
|
||||
|
||||
func createGlobalTask() (t *task) {
|
||||
t = new(task)
|
||||
t.name = "root"
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
t.subtasks = xsync.NewMapOf[*task, struct{}]()
|
||||
return
|
||||
}
|
||||
|
||||
type (
|
||||
// Task controls objects' lifetime.
|
||||
//
|
||||
// Task must be initialized, use DummyTask if the task is not yet started.
|
||||
//
|
||||
// Objects that uses a task should implement the TaskStarter and the TaskFinisher interface.
|
||||
//
|
||||
// When passing a Task object to another function,
|
||||
// it must be a sub-task of the current task,
|
||||
// in name of "`currentTaskName`Subtask"
|
||||
//
|
||||
// Use Task.Finish to stop all subtasks of the task.
|
||||
Task interface {
|
||||
TaskFinisher
|
||||
// Name returns the name of the task.
|
||||
Name() string
|
||||
// Context returns the context associated with the task. This context is
|
||||
// canceled when Finish is called.
|
||||
Context() context.Context
|
||||
// FinishCause returns the reason / error that caused the task to be finished.
|
||||
FinishCause() error
|
||||
// Parent returns the parent task of the current task.
|
||||
Parent() Task
|
||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||
//
|
||||
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
|
||||
//
|
||||
// This should not be called after Finish, Wait, or WaitSubTasks is called.
|
||||
Subtask(usageFmt string, args ...any) Task
|
||||
// OnComplete calls fn when the task and all subtasks are finished.
|
||||
//
|
||||
// It cannot be called after Finish or Wait is called.
|
||||
OnComplete(about string, fn func())
|
||||
// Wait waits for all subtasks, itself and all OnComplete to finish.
|
||||
//
|
||||
// It must be called only after Finish is called.
|
||||
Wait()
|
||||
// WaitSubTasks waits for all subtasks of the task to finish.
|
||||
//
|
||||
// No more subtasks can be added after this call.
|
||||
//
|
||||
// It can be called before Finish is called.
|
||||
WaitSubTasks()
|
||||
}
|
||||
TaskStarter interface {
|
||||
// Start starts the object that implements TaskStarter,
|
||||
// and returns an error if it fails to start.
|
||||
//
|
||||
// The task passed must be a subtask of the caller task.
|
||||
//
|
||||
// callerSubtask.Finish must be called when start fails or the object is finished.
|
||||
Start(callerSubtask Task) E.NestedError
|
||||
}
|
||||
TaskFinisher interface {
|
||||
// Finish marks the task as finished by cancelling its context.
|
||||
//
|
||||
// Then call Wait to wait for all subtasks and OnComplete of the task to finish.
|
||||
//
|
||||
// Note that it will also cancel all subtasks.
|
||||
Finish(reason string)
|
||||
}
|
||||
task struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
|
||||
parent *task
|
||||
subtasks *xsync.MapOf[*task, struct{}]
|
||||
|
||||
name, line string
|
||||
|
||||
subTasksWg, onCompleteWg sync.WaitGroup
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProgramExiting = errors.New("program exiting")
|
||||
ErrTaskCancelled = errors.New("task cancelled")
|
||||
)
|
||||
|
||||
// GlobalTask returns a new Task with the given name, derived from the global context.
|
||||
func GlobalTask(format string, args ...any) Task {
|
||||
return globalTask.Subtask(format, args...)
|
||||
}
|
||||
|
||||
// DebugTaskMap returns a map[string]any representation of the global task tree.
|
||||
//
|
||||
// The returned map is suitable for encoding to JSON, and can be used
|
||||
// to debug the task tree.
|
||||
//
|
||||
// The returned map is not guaranteed to be stable, and may change
|
||||
// between runs of the program. It is intended for debugging purposes
|
||||
// only.
|
||||
func DebugTaskMap() map[string]any {
|
||||
return globalTask.serialize()
|
||||
}
|
||||
|
||||
// CancelGlobalContext cancels the global task context, which will cause all tasks
|
||||
// created to be canceled. This should be called before exiting the program
|
||||
// to ensure that all tasks are properly cleaned up.
|
||||
func CancelGlobalContext() {
|
||||
globalTask.cancel(ErrProgramExiting)
|
||||
}
|
||||
|
||||
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func GlobalContextWait(timeout time.Duration) {
|
||||
done := make(chan struct{})
|
||||
after := time.After(timeout)
|
||||
go func() {
|
||||
globalTask.Wait()
|
||||
close(done)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-after:
|
||||
logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *task) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *task) Context() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
func (t *task) FinishCause() error {
|
||||
return context.Cause(t.ctx)
|
||||
}
|
||||
|
||||
func (t *task) Parent() Task {
|
||||
return t.parent
|
||||
}
|
||||
|
||||
func (t *task) OnComplete(about string, fn func()) {
|
||||
t.onCompleteWg.Add(1)
|
||||
var file string
|
||||
var line int
|
||||
if common.IsTrace {
|
||||
_, file, line, _ = runtime.Caller(1)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err)
|
||||
}
|
||||
}()
|
||||
defer t.onCompleteWg.Done()
|
||||
t.subTasksWg.Wait()
|
||||
<-t.ctx.Done()
|
||||
fn()
|
||||
logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about)
|
||||
t.cancel(nil) // ensure resources are released
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *task) Finish(reason string) {
|
||||
t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason))
|
||||
t.Wait()
|
||||
}
|
||||
|
||||
func (t *task) Subtask(format string, args ...any) Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||
return t.newSubTask(ctx, cancel, format)
|
||||
}
|
||||
|
||||
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
|
||||
parent := t
|
||||
subtask := &task{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
name: name,
|
||||
parent: parent,
|
||||
subtasks: xsync.NewMapOf[*task, struct{}](),
|
||||
}
|
||||
parent.subTasksWg.Add(1)
|
||||
parent.subtasks.Store(subtask, struct{}{})
|
||||
if common.IsTrace {
|
||||
_, file, line, ok := runtime.Caller(3)
|
||||
if ok {
|
||||
subtask.line = fmt.Sprintf("%s:%d", file, line)
|
||||
}
|
||||
logrus.Tracef("line %s\ntask %q started", subtask.line, name)
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
logrus.Tracef("task %q finished", subtask.Name())
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
parent.subtasks.Delete(subtask)
|
||||
parent.subTasksWg.Done()
|
||||
}()
|
||||
return subtask
|
||||
}
|
||||
|
||||
func (t *task) Wait() {
|
||||
t.subTasksWg.Wait()
|
||||
if t != globalTask {
|
||||
<-t.ctx.Done()
|
||||
}
|
||||
t.onCompleteWg.Wait()
|
||||
}
|
||||
|
||||
func (t *task) WaitSubTasks() {
|
||||
t.subTasksWg.Wait()
|
||||
}
|
||||
|
||||
// tree returns a string representation of the task tree, with the given
|
||||
// prefix prepended to each line. The prefix is used to indent the tree,
|
||||
// and should be a string of spaces or a similar separator.
|
||||
//
|
||||
// The resulting string is suitable for printing to the console, and can be
|
||||
// used to debug the task tree.
|
||||
//
|
||||
// The tree is traversed in a depth-first manner, with each task's name and
|
||||
// line number (if available) printed on a separate line. The line number is
|
||||
// only printed if the task was created with a non-empty line argument.
|
||||
//
|
||||
// The returned string is not guaranteed to be stable, and may change between
|
||||
// runs of the program. It is intended for debugging purposes only.
|
||||
func (t *task) tree(prefix ...string) string {
|
||||
var sb strings.Builder
|
||||
var pre string
|
||||
if len(prefix) > 0 {
|
||||
pre = prefix[0]
|
||||
sb.WriteString(pre + "- ")
|
||||
}
|
||||
if t.line != "" {
|
||||
sb.WriteString("line " + t.line + "\n")
|
||||
}
|
||||
if len(pre) > 0 {
|
||||
sb.WriteString(pre + "- ")
|
||||
}
|
||||
sb.WriteString(t.Name() + "\n")
|
||||
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
|
||||
sb.WriteString(subtask.tree(pre + " "))
|
||||
return true
|
||||
})
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// serialize returns a map[string]any representation of the task tree.
|
||||
//
|
||||
// The map contains the following keys:
|
||||
// - name: the name of the task
|
||||
// - line: the line number of the task, if available
|
||||
// - subtasks: a slice of maps, each representing a subtask
|
||||
//
|
||||
// The subtask maps contain the same keys, recursively.
|
||||
//
|
||||
// The returned map is suitable for encoding to JSON, and can be used
|
||||
// to debug the task tree.
|
||||
//
|
||||
// The returned map is not guaranteed to be stable, and may change
|
||||
// between runs of the program. It is intended for debugging purposes
|
||||
// only.
|
||||
func (t *task) serialize() map[string]any {
|
||||
m := make(map[string]any)
|
||||
m["name"] = t.name
|
||||
if t.line != "" {
|
||||
m["line"] = t.line
|
||||
}
|
||||
if t.subtasks.Size() > 0 {
|
||||
m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
|
||||
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
|
||||
m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
|
||||
return true
|
||||
})
|
||||
}
|
||||
return m
|
||||
}
|
147
internal/task/task_test.go
Normal file
147
internal/task/task_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package task_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/task"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
defer CancelGlobalContext()
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
subTask := rootTask.Subtask("subtask")
|
||||
|
||||
ExpectEqual(t, "root-task", rootTask.Name())
|
||||
ExpectEqual(t, "subtask", subTask.Name())
|
||||
}
|
||||
|
||||
func TestTaskCancellation(t *testing.T) {
|
||||
defer CancelGlobalContext()
|
||||
|
||||
subTaskDone := make(chan struct{})
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
subTask := rootTask.Subtask("subtask")
|
||||
|
||||
go func() {
|
||||
subTask.Wait()
|
||||
close(subTaskDone)
|
||||
}()
|
||||
|
||||
go rootTask.Finish("done")
|
||||
|
||||
select {
|
||||
case <-subTaskDone:
|
||||
err := subTask.Context().Err()
|
||||
ExpectError(t, context.Canceled, err)
|
||||
cause := context.Cause(subTask.Context())
|
||||
ExpectError(t, ErrTaskCancelled, cause)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("subTask context was not canceled as expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalContextCancellation(t *testing.T) {
|
||||
taskDone := make(chan struct{})
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
|
||||
go func() {
|
||||
rootTask.Wait()
|
||||
close(taskDone)
|
||||
}()
|
||||
|
||||
CancelGlobalContext()
|
||||
|
||||
select {
|
||||
case <-taskDone:
|
||||
err := rootTask.Context().Err()
|
||||
ExpectError(t, context.Canceled, err)
|
||||
cause := context.Cause(rootTask.Context())
|
||||
ExpectError(t, ErrProgramExiting, cause)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("subTask context was not canceled as expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnComplete(t *testing.T) {
|
||||
defer CancelGlobalContext()
|
||||
|
||||
task := GlobalTask("test")
|
||||
var value atomic.Int32
|
||||
task.OnComplete("set value", func() {
|
||||
value.Store(1234)
|
||||
})
|
||||
task.Finish("done")
|
||||
ExpectEqual(t, value.Load(), 1234)
|
||||
}
|
||||
|
||||
func TestGlobalContextWait(t *testing.T) {
|
||||
defer CancelGlobalContext()
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
|
||||
finished1, finished2 := false, false
|
||||
|
||||
subTask1 := rootTask.Subtask("subtask1")
|
||||
subTask2 := rootTask.Subtask("subtask2")
|
||||
subTask1.OnComplete("set finished", func() {
|
||||
finished1 = true
|
||||
})
|
||||
subTask2.OnComplete("set finished", func() {
|
||||
finished2 = true
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask1.Finish("done")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask2.Finish("done")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
subTask1.Wait()
|
||||
subTask2.Wait()
|
||||
rootTask.Finish("done")
|
||||
}()
|
||||
|
||||
GlobalContextWait(1 * time.Second)
|
||||
ExpectTrue(t, finished1)
|
||||
ExpectTrue(t, finished2)
|
||||
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
||||
ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context()))
|
||||
ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context()))
|
||||
}
|
||||
|
||||
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
||||
defer CancelGlobalContext()
|
||||
|
||||
rootTask := GlobalTask("root-task")
|
||||
subTask := rootTask.Subtask("subtask")
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
GlobalContextWait(500 * time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("GlobalContextWait should have timed out")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
|
||||
// Ensure clean exit
|
||||
subTask.Finish("exit")
|
||||
}
|
||||
|
||||
func TestGlobalContextCancel(t *testing.T) {
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package types
|
||||
|
||||
type Config struct {
|
||||
Providers ProxyProviders `json:"providers" yaml:",flow"`
|
||||
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
|
||||
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
|
||||
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
|
||||
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
|
||||
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Providers: ProxyProviders{},
|
||||
TimeoutShutdown: 3,
|
||||
RedirectToHTTPS: false,
|
||||
}
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
package types
|
||||
|
||||
type ProxyProviders struct {
|
||||
Files []string `json:"include" yaml:"include"` // docker, file
|
||||
Docker map[string]string `json:"docker" yaml:"docker"`
|
||||
}
|
|
@ -23,7 +23,7 @@ func IgnoreError[Result any](r Result, _ error) Result {
|
|||
func ExpectNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil && !reflect.ValueOf(err).IsNil() {
|
||||
t.Errorf("expected err=nil, got %s", err.Error())
|
||||
t.Errorf("expected err=nil, got %s", err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func ExpectNoError(t *testing.T, err error) {
|
|||
func ExpectError(t *testing.T, expected error, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Errorf("expected err %s, got %s", expected.Error(), err.Error())
|
||||
t.Errorf("expected err %s, got %s", expected, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func ExpectError(t *testing.T, expected error, err error) {
|
|||
func ExpectError2(t *testing.T, input any, expected error, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error())
|
||||
t.Errorf("%v: expected err %s, got %s", input, expected, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,8 +15,9 @@ import (
|
|||
|
||||
type (
|
||||
DockerWatcher struct {
|
||||
host string
|
||||
client D.Client
|
||||
host string
|
||||
client D.Client
|
||||
clientOwned bool
|
||||
logrus.FieldLogger
|
||||
}
|
||||
DockerListOptions = docker_events.ListOptions
|
||||
|
@ -44,10 +45,11 @@ func DockerrFilterContainer(nameOrID string) filters.KeyValuePair {
|
|||
|
||||
func NewDockerWatcher(host string) DockerWatcher {
|
||||
return DockerWatcher{
|
||||
host: host,
|
||||
clientOwned: true,
|
||||
FieldLogger: (logrus.
|
||||
WithField("module", "docker_watcher").
|
||||
WithField("host", host)),
|
||||
host: host,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +74,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
|
|||
defer close(errCh)
|
||||
|
||||
defer func() {
|
||||
if w.client.Connected() {
|
||||
if w.clientOwned && w.client.Connected() {
|
||||
w.client.Close()
|
||||
}
|
||||
}()
|
||||
|
|
91
internal/watcher/events/event_queue.go
Normal file
91
internal/watcher/events/event_queue.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package events
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type (
|
||||
EventQueue struct {
|
||||
task task.Task
|
||||
queue []Event
|
||||
ticker *time.Ticker
|
||||
onFlush OnFlushFunc
|
||||
onError OnErrorFunc
|
||||
}
|
||||
OnFlushFunc = func(flushTask task.Task, events []Event)
|
||||
OnErrorFunc = func(err E.NestedError)
|
||||
)
|
||||
|
||||
const eventQueueCapacity = 10
|
||||
|
||||
// NewEventQueue returns a new EventQueue with the given
|
||||
// queueTask, flushInterval, onFlush and onError.
|
||||
//
|
||||
// The returned EventQueue will start a goroutine to flush events in the queue
|
||||
// when the flushInterval is reached.
|
||||
//
|
||||
// The onFlush function is called when the flushInterval is reached and the queue is not empty,
|
||||
//
|
||||
// The onError function is called when an error received from the errCh,
|
||||
// or panic occurs in the onFlush function. Panic will cause a E.ErrPanicRecv error.
|
||||
//
|
||||
// flushTask.Finish must be called after the flush is done,
|
||||
// but the onFlush function can return earlier (e.g. run in another goroutine).
|
||||
//
|
||||
// If task is cancelled before the flushInterval is reached, the events in queue will be discarded.
|
||||
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
|
||||
return &EventQueue{
|
||||
task: parent.Subtask("event queue"),
|
||||
queue: make([]Event, 0, eventQueueCapacity),
|
||||
ticker: time.NewTicker(flushInterval),
|
||||
onFlush: onFlush,
|
||||
onError: onError,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) {
|
||||
go func() {
|
||||
defer e.ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-e.task.Context().Done():
|
||||
e.task.Finish(e.task.FinishCause().Error())
|
||||
return
|
||||
case <-e.ticker.C:
|
||||
if len(e.queue) > 0 {
|
||||
flushTask := e.task.Subtask("flush events")
|
||||
queue := e.queue
|
||||
e.queue = make([]Event, 0, eventQueueCapacity)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
e.onError(E.PanicRecv("panic in onFlush %s", err))
|
||||
}
|
||||
}()
|
||||
e.onFlush(flushTask, queue)
|
||||
}()
|
||||
flushTask.Wait()
|
||||
}
|
||||
case event, ok := <-eventCh:
|
||||
e.queue = append(e.queue, event)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
e.onError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait waits for all events to be flushed and the task to finish.
|
||||
//
|
||||
// It is safe to call this method multiple times.
|
||||
func (e *EventQueue) Wait() {
|
||||
e.task.Wait()
|
||||
}
|
|
@ -74,7 +74,7 @@ var actionNameMap = func() (m map[Action]string) {
|
|||
}()
|
||||
|
||||
func (e Event) String() string {
|
||||
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
|
||||
return fmt.Sprintf("%s %s", e.Action, e.ActorName)
|
||||
}
|
||||
|
||||
func (a Action) String() string {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
|
@ -15,10 +14,10 @@ type HTTPHealthMonitor struct {
|
|||
pinger *http.Client
|
||||
}
|
||||
|
||||
func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor {
|
||||
func NewHTTPHealthMonitor(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) *HTTPHealthMonitor {
|
||||
mon := new(HTTPHealthMonitor)
|
||||
mon.monitor = newMonitor(task, url, config, mon.checkHealth)
|
||||
mon.pinger = &http.Client{Timeout: config.Timeout}
|
||||
mon.monitor = newMonitor(url, config, mon.CheckHealth)
|
||||
mon.pinger = &http.Client{Timeout: config.Timeout, Transport: transport}
|
||||
if config.UseGet {
|
||||
mon.method = http.MethodGet
|
||||
} else {
|
||||
|
@ -27,19 +26,26 @@ func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckCo
|
|||
return mon
|
||||
}
|
||||
|
||||
func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) {
|
||||
func NewHTTPHealthChecker(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) HealthChecker {
|
||||
return NewHTTPHealthMonitor(url, config, transport)
|
||||
}
|
||||
|
||||
func (mon *HTTPHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
|
||||
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
|
||||
defer cancel()
|
||||
|
||||
req, reqErr := http.NewRequestWithContext(
|
||||
mon.task.Context(),
|
||||
ctx,
|
||||
mon.method,
|
||||
mon.url.JoinPath(mon.config.Path).String(),
|
||||
mon.url.Load().JoinPath(mon.config.Path).String(),
|
||||
nil,
|
||||
)
|
||||
if reqErr != nil {
|
||||
err = reqErr
|
||||
return
|
||||
}
|
||||
req.Header.Set("Connection", "close")
|
||||
|
||||
req.Header.Set("Connection", "close")
|
||||
resp, respErr := mon.pinger.Do(req)
|
||||
if respErr == nil {
|
||||
resp.Body.Close()
|
||||
|
|
|
@ -2,78 +2,93 @@ package health
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
HealthMonitor interface {
|
||||
Start()
|
||||
Stop()
|
||||
task.TaskStarter
|
||||
task.TaskFinisher
|
||||
fmt.Stringer
|
||||
json.Marshaler
|
||||
Status() Status
|
||||
Uptime() time.Duration
|
||||
Name() string
|
||||
String() string
|
||||
MarshalJSON() ([]byte, error)
|
||||
}
|
||||
HealthChecker interface {
|
||||
CheckHealth() (healthy bool, detail string, err error)
|
||||
URL() types.URL
|
||||
Config() *HealthCheckConfig
|
||||
UpdateURL(url types.URL)
|
||||
}
|
||||
HealthCheckFunc func() (healthy bool, detail string, err error)
|
||||
monitor struct {
|
||||
service string
|
||||
config *HealthCheckConfig
|
||||
url types.URL
|
||||
url U.AtomicValue[types.URL]
|
||||
|
||||
status U.AtomicValue[Status]
|
||||
checkHealth HealthCheckFunc
|
||||
startTime time.Time
|
||||
|
||||
task common.Task
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
task task.Task
|
||||
}
|
||||
)
|
||||
|
||||
var monMap = F.NewMapOf[string, HealthMonitor]()
|
||||
|
||||
func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
|
||||
service := task.Name()
|
||||
task, cancel := task.SubtaskWithCancel("Health monitor for %s", service)
|
||||
func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
|
||||
mon := &monitor{
|
||||
service: service,
|
||||
config: config,
|
||||
url: url,
|
||||
checkHealth: healthCheckFunc,
|
||||
startTime: time.Now(),
|
||||
task: task,
|
||||
cancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
task: task.DummyTask(),
|
||||
}
|
||||
mon.url.Store(url)
|
||||
mon.status.Store(StatusHealthy)
|
||||
return mon
|
||||
}
|
||||
|
||||
func Inspect(name string) (HealthMonitor, bool) {
|
||||
return monMap.Load(name)
|
||||
func Inspect(service string) (HealthMonitor, bool) {
|
||||
return monMap.Load(service)
|
||||
}
|
||||
|
||||
func (mon *monitor) Start() {
|
||||
defer monMap.Store(mon.task.Name(), mon)
|
||||
func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) {
|
||||
if mon.task != nil {
|
||||
return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause))
|
||||
} else {
|
||||
return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause))
|
||||
}
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
|
||||
mon.service = routeSubtask.Parent().Name()
|
||||
mon.task = routeSubtask
|
||||
|
||||
if err := mon.checkUpdateHealth(); err != nil {
|
||||
mon.task.Finish(fmt.Sprintf("healthchecker %s failure: %s", mon.service, err))
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
defer close(mon.done)
|
||||
defer mon.task.Finished()
|
||||
defer func() {
|
||||
monMap.Delete(mon.task.Name())
|
||||
if mon.status.Load() != StatusError {
|
||||
mon.status.Store(StatusUnknown)
|
||||
}
|
||||
mon.task.Finish(mon.task.FinishCause().Error())
|
||||
}()
|
||||
|
||||
ok := mon.checkUpdateHealth()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
monMap.Store(mon.service, mon)
|
||||
|
||||
ticker := time.NewTicker(mon.config.Interval)
|
||||
defer ticker.Stop()
|
||||
|
@ -83,48 +98,61 @@ func (mon *monitor) Start() {
|
|||
case <-mon.task.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ok = mon.checkUpdateHealth()
|
||||
if !ok {
|
||||
err := mon.checkUpdateHealth()
|
||||
if err != nil {
|
||||
logger.Errorf("healthchecker %s failure: %s", mon.service, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mon *monitor) Stop() {
|
||||
monMap.Delete(mon.task.Name())
|
||||
|
||||
mon.mu.Lock()
|
||||
defer mon.mu.Unlock()
|
||||
|
||||
if mon.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mon.cancel()
|
||||
<-mon.done
|
||||
|
||||
mon.cancel = nil
|
||||
mon.status.Store(StatusUnknown)
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (mon *monitor) Finish(reason string) {
|
||||
mon.task.Finish(reason)
|
||||
}
|
||||
|
||||
// UpdateURL implements HealthChecker.
|
||||
func (mon *monitor) UpdateURL(url types.URL) {
|
||||
mon.url.Store(url)
|
||||
}
|
||||
|
||||
// URL implements HealthChecker.
|
||||
func (mon *monitor) URL() types.URL {
|
||||
return mon.url.Load()
|
||||
}
|
||||
|
||||
// Config implements HealthChecker.
|
||||
func (mon *monitor) Config() *HealthCheckConfig {
|
||||
return mon.config
|
||||
}
|
||||
|
||||
// Status implements HealthMonitor.
|
||||
func (mon *monitor) Status() Status {
|
||||
return mon.status.Load()
|
||||
}
|
||||
|
||||
// Uptime implements HealthMonitor.
|
||||
func (mon *monitor) Uptime() time.Duration {
|
||||
return time.Since(mon.startTime)
|
||||
}
|
||||
|
||||
// Name implements HealthMonitor.
|
||||
func (mon *monitor) Name() string {
|
||||
if mon.task == nil {
|
||||
return ""
|
||||
}
|
||||
return mon.task.Name()
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer of HealthMonitor.
|
||||
func (mon *monitor) String() string {
|
||||
return mon.Name()
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler of HealthMonitor.
|
||||
func (mon *monitor) MarshalJSON() ([]byte, error) {
|
||||
return (&JSONRepresentation{
|
||||
Name: mon.service,
|
||||
|
@ -132,19 +160,19 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
|
|||
Status: mon.status.Load(),
|
||||
Started: mon.startTime,
|
||||
Uptime: mon.Uptime(),
|
||||
URL: mon.url,
|
||||
URL: mon.url.Load(),
|
||||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
func (mon *monitor) checkUpdateHealth() (hasError bool) {
|
||||
func (mon *monitor) checkUpdateHealth() E.NestedError {
|
||||
healthy, detail, err := mon.checkHealth()
|
||||
if err != nil {
|
||||
defer mon.task.Finish(err.Error())
|
||||
mon.status.Store(StatusError)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
logger.Errorf("%s failed to check health: %s", mon.service, err)
|
||||
return E.Failure("check health").With(err)
|
||||
}
|
||||
mon.Stop()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
var status Status
|
||||
if healthy {
|
||||
|
@ -160,5 +188,5 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
|
|||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package health
|
|||
import (
|
||||
"net"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
|
@ -14,9 +13,9 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor {
|
||||
func NewRawHealthMonitor(url types.URL, config *HealthCheckConfig) *RawHealthMonitor {
|
||||
mon := new(RawHealthMonitor)
|
||||
mon.monitor = newMonitor(task, url, config, mon.checkAvail)
|
||||
mon.monitor = newMonitor(url, config, mon.CheckHealth)
|
||||
mon.dialer = &net.Dialer{
|
||||
Timeout: config.Timeout,
|
||||
FallbackDelay: -1,
|
||||
|
@ -24,14 +23,22 @@ func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckCon
|
|||
return mon
|
||||
}
|
||||
|
||||
func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) {
|
||||
conn, dialErr := mon.dialer.DialContext(mon.task.Context(), mon.url.Scheme, mon.url.Host)
|
||||
func NewRawHealthChecker(url types.URL, config *HealthCheckConfig) HealthChecker {
|
||||
return NewRawHealthMonitor(url, config)
|
||||
}
|
||||
|
||||
func (mon *RawHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
|
||||
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
|
||||
defer cancel()
|
||||
|
||||
url := mon.url.Load()
|
||||
conn, dialErr := mon.dialer.DialContext(ctx, url.Scheme, url.Host)
|
||||
if dialErr != nil {
|
||||
detail = dialErr.Error()
|
||||
/* trunk-ignore(golangci-lint/nilerr) */
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
avail = true
|
||||
healthy = true
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue