Fixed a few issues:

- Incorrect name shown on the dashboard "Proxies" page
- Apps shown on the homepage even when homepage.show is false
- Load-balanced routes shown on the homepage instead of the load balancer
- Routes with idlewatcher are now removed on container destroy
- Idlewatcher panic
- Performance improvements
- Idlewatcher loading infinitely
- Reload stuck / not working properly
- Streams stuck on shutdown / reload
- etc.

Added:

- Support for idlewatcher on load-balanced routes
- Partial implementation of stream-type idlewatcher

Known issues:

- Graceful shutdown
yusing 2024-10-18 16:47:01 +08:00
parent c0c61709ca
commit 53557e38b6
69 changed files with 2368 additions and 1654 deletions

View file

@@ -30,6 +30,12 @@ get:
debug:
make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
debug-trace:
make build && sudo GOPROXY_DEBUG=1 GOPROXY_TRACE=1 bin/go-proxy
profile:
GODEBUG=gctrace=1 make build && sudo GOPROXY_DEBUG=1 bin/go-proxy
mtrace:
bin/go-proxy debug-ls-mtrace > mtrace.json

View file

@@ -20,6 +20,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg"
)
@@ -32,8 +33,14 @@ func main() {
}
l := logrus.WithField("module", "main")
timeFmt := "01-02 15:04:05"
fullTS := true
if common.IsDebug {
if common.IsTrace {
logrus.SetLevel(logrus.TraceLevel)
timeFmt = "04:05"
fullTS = false
} else if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel)
}
@@ -42,9 +49,9 @@ func main() {
} else {
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
FullTimestamp: true,
FullTimestamp: fullTS,
ForceColors: true,
TimestampFormat: "01-02 15:04:05",
TimestampFormat: timeFmt,
})
logrus.Infof("go-proxy version %s", pkg.GetVersion())
}
@@ -76,21 +83,22 @@ func main() {
middleware.LoadComposeFiles()
if err := config.Load(); err != nil {
var cfg *config.Config
var err E.NestedError
if cfg, err = config.Load(); err != nil {
logrus.Warn(err)
}
cfg := config.GetInstance()
switch args.Command {
case common.CommandListConfigs:
printJSON(cfg.Value())
printJSON(config.Value())
return
case common.CommandListRoutes:
routes, err := query.ListRoutes()
if err != nil {
log.Printf("failed to connect to api server: %s", err)
log.Printf("falling back to config file")
printJSON(cfg.RoutesByAlias())
printJSON(config.RoutesByAlias())
} else {
printJSON(routes)
}
@@ -103,10 +111,10 @@ func main() {
printJSON(icons)
return
case common.CommandDebugListEntries:
printJSON(cfg.DumpEntries())
printJSON(config.DumpEntries())
return
case common.CommandDebugListProviders:
printJSON(cfg.DumpProviders())
printJSON(config.DumpProviders())
return
case common.CommandDebugListMTrace:
trace, err := query.ListMiddlewareTraces()
@@ -114,17 +122,25 @@ func main() {
log.Fatal(err)
}
printJSON(trace)
return
case common.CommandDebugListTasks:
tasks, err := query.DebugListTasks()
if err != nil {
log.Fatal(err)
}
printJSON(tasks)
return
}
cfg.StartProxyProviders()
cfg.WatchChanges()
config.WatchChanges()
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT)
signal.Notify(sig, syscall.SIGTERM)
signal.Notify(sig, syscall.SIGHUP)
autocert := cfg.GetAutoCertProvider()
autocert := config.GetAutoCertProvider()
if autocert != nil {
if err := autocert.Setup(); err != nil {
l.Fatal(err)
@@ -139,14 +155,14 @@ func main() {
HTTPAddr: common.ProxyHTTPAddr,
HTTPSAddr: common.ProxyHTTPSAddr,
Handler: http.HandlerFunc(R.ProxyHandler),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
RedirectToHTTPS: config.Value().RedirectToHTTPS,
})
apiServer := server.InitAPIServer(server.Options{
Name: "api",
CertProvider: autocert,
HTTPAddr: common.APIHTTPAddr,
Handler: api.NewHandler(cfg),
RedirectToHTTPS: cfg.Value().RedirectToHTTPS,
Handler: api.NewHandler(),
RedirectToHTTPS: config.Value().RedirectToHTTPS,
})
proxyServer.Start()
@@ -157,8 +173,8 @@ func main() {
// graceful shutdown
logrus.Info("shutting down")
common.CancelGlobalContext()
common.GlobalContextWait(time.Second * time.Duration(cfg.Value().TimeoutShutdown))
task.CancelGlobalContext()
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
}
func prepareDirectory(dir string) {

View file

@@ -2,6 +2,7 @@ package api
import (
"fmt"
"net"
"net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1"
@@ -21,34 +22,35 @@ func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc
mux.ServeMux.HandleFunc(fmt.Sprintf("%s %s", method, endpoint), checkHost(handler))
}
func NewHandler(cfg *config.Config) http.Handler {
func NewHandler() http.Handler {
mux := NewServeMux()
mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("GET", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("HEAD", "/v1/checkhealth", wrap(cfg, v1.CheckHealth))
mux.HandleFunc("POST", "/v1/reload", wrap(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List))
mux.HandleFunc("GET", "/v1/checkhealth", v1.CheckHealth)
mux.HandleFunc("HEAD", "/v1/checkhealth", v1.CheckHealth)
mux.HandleFunc("POST", "/v1/reload", v1.Reload)
mux.HandleFunc("GET", "/v1/list", v1.List)
mux.HandleFunc("GET", "/v1/list/{what}", v1.List)
mux.HandleFunc("GET", "/v1/file", v1.GetFileContent)
mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent)
mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS))
mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux
}
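
For reference, the `"METHOD /path"` strings and the `{what}` / `{filename...}` wildcards registered above rely on Go 1.22's net/http pattern routing; a standalone toy example of that mechanism (not part of this commit):

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	// "GET /v1/list/{what}" matches only GET requests; the {what}
	// segment is read back with r.PathValue.
	mux.HandleFunc("GET /v1/list/{what}", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "listing %q\n", r.PathValue("what"))
	})
	log.Fatal(http.ListenAndServe(":8080", mux))
}
```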
// allow only requests to API server with host matching common.APIHTTPAddr.
// allow only requests to the API server from localhost.
func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug {
return f
}
return func(w http.ResponseWriter, r *http.Request) {
if r.Host != common.APIHTTPAddr {
Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr)
http.Error(w, "invalid request", http.StatusForbidden)
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "::1" { // net.SplitHostPort strips the brackets from IPv6 addresses
Logger.Warnf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return
}
f(w, r)

View file

@@ -4,11 +4,10 @@ import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/watcher/health"
)
func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
func CheckHealth(w http.ResponseWriter, r *http.Request) {
target := r.FormValue("target")
if target == "" {
HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest)

View file

@@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/provider"
"github.com/yusing/go-proxy/internal/route/provider"
)
func GetFileContent(w http.ResponseWriter, r *http.Request) {

View file

@@ -9,19 +9,21 @@ import (
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
const (
ListRoutes = "routes"
ListConfigFiles = "config_files"
ListMiddlewares = "middlewares"
ListMiddlewareTrace = "middleware_trace"
ListMatchDomains = "match_domains"
ListHomepageConfig = "homepage_config"
ListRoutes = "routes"
ListConfigFiles = "config_files"
ListMiddlewares = "middlewares"
ListMiddlewareTraces = "middleware_trace"
ListMatchDomains = "match_domains"
ListHomepageConfig = "homepage_config"
ListTasks = "tasks"
)
func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
func List(w http.ResponseWriter, r *http.Request) {
what := r.PathValue("what")
if what == "" {
what = ListRoutes
@@ -29,27 +31,24 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
switch what {
case ListRoutes:
listRoutes(cfg, w, r)
U.RespondJSON(w, r, config.RoutesByAlias(route.RouteType(r.FormValue("type"))))
case ListConfigFiles:
listConfigFiles(w, r)
case ListMiddlewares:
listMiddlewares(w, r)
case ListMiddlewareTrace:
listMiddlewareTrace(w, r)
U.RespondJSON(w, r, middleware.All())
case ListMiddlewareTraces:
U.RespondJSON(w, r, middleware.GetAllTrace())
case ListMatchDomains:
listMatchDomains(cfg, w, r)
U.RespondJSON(w, r, config.Value().MatchDomains)
case ListHomepageConfig:
listHomepageConfig(cfg, w, r)
U.RespondJSON(w, r, config.HomepageConfig())
case ListTasks:
U.RespondJSON(w, r, task.DebugTaskMap())
default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
}
}
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type")))
U.RespondJSON(w, r, routes)
}
func listConfigFiles(w http.ResponseWriter, r *http.Request) {
files, err := utils.ListFiles(common.ConfigBasePath, 1)
if err != nil {
@@ -61,19 +60,3 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) {
}
U.RespondJSON(w, r, files)
}
func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, middleware.GetAllTrace())
}
func listMiddlewares(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, middleware.All())
}
func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, cfg.Value().MatchDomains)
}
func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, cfg.HomepageConfig())
}

View file

@@ -34,36 +34,34 @@ func ReloadServer() E.NestedError {
return nil
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes))
func List[T any](what string) (_ T, outErr E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
if err != nil {
return nil, E.From(err)
outErr = E.From(err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list routes").Extraf("status code: %v", resp.StatusCode)
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode)
return
}
var routes map[string]map[string]any
err = json.NewDecoder(resp.Body).Decode(&routes)
var res T
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return nil, E.From(err)
outErr = E.From(err)
return
}
return routes, nil
return res, nil
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
return List[map[string]map[string]any](v1.ListRoutes)
}
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace))
if err != nil {
return nil, E.From(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode)
}
var traces middleware.Traces
err = json.NewDecoder(resp.Body).Decode(&traces)
if err != nil {
return nil, E.From(err)
}
return traces, nil
return List[middleware.Traces](v1.ListMiddlewareTraces)
}
func DebugListTasks() (map[string]any, E.NestedError) {
return List[map[string]any](v1.ListTasks)
}
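
With the generic helper above, further endpoints reduce to one-liners. A hypothetical extra caller, shown as a fragment using the surrounding file's imports (v1.ListMiddlewares is defined in this diff, but the map shape here is an assumption):

```go
// Hypothetical: fetch the middleware list via the same generic helper.
func ListMiddlewares() (map[string]any, E.NestedError) {
	return List[map[string]any](v1.ListMiddlewares)
}
```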

View file

@@ -7,8 +7,8 @@ import (
"github.com/yusing/go-proxy/internal/config"
)
func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
if err := cfg.Reload(); err != nil {
func Reload(w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil {
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError)
} else {
w.WriteHeader(http.StatusOK)

View file

@@ -14,19 +19,19 @@ import (
"github.com/yusing/go-proxy/internal/utils"
)
func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats(cfg))
func Stats(w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats())
}
func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
func StatsWS(w http.ResponseWriter, r *http.Request) {
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
originPats := make([]string, len(cfg.Value().MatchDomains)+len(localAddresses))
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket requests from all origins")
originPats = []string{"*"}
} else {
for i, domain := range cfg.Value().MatchDomains {
for i, domain := range config.Value().MatchDomains {
originPats[i] = "*." + domain
}
originPats = append(originPats, localAddresses...)
@@ -51,7 +51,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
defer ticker.Stop()
for range ticker.C {
stats := getStats(cfg)
stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err)
return
@@ -59,9 +59,9 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
}
}
func getStats(cfg *config.Config) map[string]any {
func getStats() map[string]any {
return map[string]any{
"proxies": cfg.Statistics(),
"proxies": config.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()),
}
}

View file

@@ -9,7 +9,7 @@ import (
"github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/config/types"
)
type Config types.AutoCertConfig

View file

@@ -13,9 +13,9 @@ import (
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
)
@@ -140,23 +140,22 @@ func (p *Provider) ScheduleRenewal() {
if p.GetName() == ProviderLocal {
return
}
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
task := common.NewTask("cert renew scheduler")
defer task.Finished()
for {
select {
case <-task.Context().Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
go func() {
task := task.GlobalTask("cert renew scheduler")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
defer task.Finish("cert renew scheduler stopped")
for {
select {
case <-task.Context().Done():
return
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
}
}
}
}
}()
}
func (p *Provider) initClient() E.NestedError {

View file

@@ -17,7 +17,7 @@ func (p *Provider) Setup() (err E.NestedError) {
}
}
go p.ScheduleRenewal()
p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry)

View file

@@ -22,6 +22,7 @@ const (
CommandDebugListEntries = "debug-ls-entries"
CommandDebugListProviders = "debug-ls-providers"
CommandDebugListMTrace = "debug-ls-mtrace"
CommandDebugListTasks = "debug-ls-tasks"
)
var ValidCommands = []string{
@@ -35,6 +36,7 @@ var ValidCommands = []string{
CommandDebugListEntries,
CommandDebugListProviders,
CommandDebugListMTrace,
CommandDebugListTasks,
}
func GetArgs() Args {

View file

@@ -43,7 +43,6 @@ const (
HealthCheckIntervalDefault = 5 * time.Second
HealthCheckTimeoutDefault = 5 * time.Second
IdleTimeoutDefault = "0"
WakeTimeoutDefault = "30s"
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"

View file

@@ -15,6 +15,7 @@ var (
NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", true)
IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest)
IsTrace = GetEnvBool("GOPROXY_TRACE", false) && IsDebug
ProxyHTTPAddr,
ProxyHTTPHost,

View file

@@ -1,224 +0,0 @@
package common
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/sirupsen/logrus"
)
var (
globalCtx, globalCtxCancel = context.WithCancel(context.Background())
taskWg sync.WaitGroup
tasksMap = xsync.NewMapOf[*task, struct{}]()
)
type (
Task interface {
Name() string
Context() context.Context
Subtask(usageFmt string, args ...interface{}) Task
SubtaskWithCancel(usageFmt string, args ...interface{}) (Task, context.CancelFunc)
Finished()
}
task struct {
ctx context.Context
subtasks []*task
name string
finished bool
mu sync.Mutex
}
)
func (t *task) Name() string {
return t.name
}
// Context returns the context associated with the task. This context is
// canceled when the task is finished.
func (t *task) Context() context.Context {
return t.ctx
}
// Finished marks the task as finished and notifies the global wait group.
// Finished is thread-safe and idempotent.
func (t *task) Finished() {
t.mu.Lock()
defer t.mu.Unlock()
if t.finished {
return
}
t.finished = true
if _, ok := tasksMap.Load(t); ok {
taskWg.Done()
tasksMap.Delete(t)
}
logrus.Debugf("task %q finished", t.Name())
}
// Subtask returns a new subtask with the given name, derived from the receiver's context.
//
// The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be
// canceled immediately.
//
// The returned subtask is safe for concurrent use.
func (t *task) Subtask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
t.mu.Lock()
defer t.mu.Unlock()
sub := newSubTask(t.ctx, format)
t.subtasks = append(t.subtasks, sub)
return sub
}
// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context,
// and a cancel function. The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be canceled immediately.
//
// The returned cancel function is safe for concurrent use, and can be used to cancel the returned
// subtask at any time.
func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
t.mu.Lock()
defer t.mu.Unlock()
ctx, cancel := context.WithCancel(t.ctx)
sub := newSubTask(ctx, format)
t.subtasks = append(t.subtasks, sub)
return sub, cancel
}
func (t *task) tree(prefix ...string) string {
var sb strings.Builder
var pre string
if len(prefix) > 0 {
pre = prefix[0]
}
sb.WriteString(pre)
sb.WriteString(t.Name() + "\n")
for _, sub := range t.subtasks {
if sub.finished {
continue
}
sb.WriteString(sub.tree(pre + " "))
}
return sb.String()
}
func newSubTask(ctx context.Context, name string) *task {
t := &task{
ctx: ctx,
name: name,
}
tasksMap.Store(t, struct{}{})
taskWg.Add(1)
logrus.Debugf("task %q started", name)
return t
}
// NewTask returns a new Task with the given name, derived from the global
// context.
//
// The returned Task is associated with the global context and will be
// automatically registered and deregistered from the global context's wait
// group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is not safe for concurrent use.
func NewTask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return newSubTask(globalCtx, format)
}
// NewTaskWithCancel returns a new Task with the given name, derived from the
// global context, and a cancel function. The returned Task is associated with
// the global context and will be automatically registered and deregistered
// from the global task wait group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is safe for concurrent use.
//
// The returned cancel function is safe for concurrent use, and can be used
// to cancel the returned Task at any time.
func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
subCtx, cancel := context.WithCancel(globalCtx)
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return newSubTask(subCtx, format), cancel
}
// GlobalTask returns a new Task with the given name, associated with the
// global context.
//
// Unlike NewTask, GlobalTask does not automatically register or deregister
// the Task with the global task wait group. The returned Task is not
// started, but the name is formatted immediately.
//
// This is best used for a main task that does not need to be waited on
// and that will create a bunch of subtasks.
//
// The returned Task is safe for concurrent use.
func GlobalTask(format string, args ...interface{}) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return &task{
ctx: globalCtx,
name: format,
}
}
// CancelGlobalContext cancels the global context, which will cause all tasks
// created by GlobalTask or NewTask to be canceled. This should be called
// before exiting the program to ensure that all tasks are properly cleaned
// up.
func CancelGlobalContext() {
globalCtxCancel()
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) {
done := make(chan struct{})
after := time.After(timeout)
go func() {
taskWg.Wait()
close(done)
}()
for {
select {
case <-done:
return
case <-after:
logrus.Warnln("Timeout waiting for these tasks to finish:")
tasksMap.Range(func(t *task, _ struct{}) bool {
logrus.Warnln(t.tree())
return true
})
return
}
}
}
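
The file above is removed in favor of a new internal/task package that this diff does not include. Below is a minimal, hypothetical sketch of the surface implied by its call sites in this commit (GlobalTask, Subtask, Finish, FinishCause, Wait, OnComplete, Parent, CancelGlobalContext, GlobalContextWait); it omits extras such as DebugTaskMap, and names and semantics are inferred, not the actual implementation:

```go
package task

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

var (
	globalCtx, globalCancel = context.WithCancel(context.Background())
	globalWg                sync.WaitGroup
)

// Task is the interface implied by call sites such as
// task.GlobalTask(...).Subtask(...), t.Finish(reason), and t.Wait().
type Task interface {
	Name() string
	Context() context.Context
	Subtask(format string, args ...any) Task
	Finish(reason string)
	FinishCause() error
	Wait()
	OnComplete(about string, fn func())
	Parent() Task
}

type taskImpl struct {
	ctx    context.Context
	cancel context.CancelCauseFunc
	name   string
	parent *taskImpl
	subWg  sync.WaitGroup
	once   sync.Once
}

// GlobalTask returns a root task derived from the global context.
func GlobalTask(format string, args ...any) Task {
	ctx, cancel := context.WithCancelCause(globalCtx)
	return &taskImpl{ctx: ctx, cancel: cancel, name: fmt.Sprintf(format, args...)}
}

func (t *taskImpl) Name() string             { return t.name }
func (t *taskImpl) Context() context.Context { return t.ctx }

func (t *taskImpl) Parent() Task {
	if t.parent == nil {
		return nil
	}
	return t.parent
}

// Subtask registers a child task; the parent's Wait blocks until every
// child has called Finish.
func (t *taskImpl) Subtask(format string, args ...any) Task {
	ctx, cancel := context.WithCancelCause(t.ctx)
	t.subWg.Add(1)
	globalWg.Add(1)
	return &taskImpl{ctx: ctx, cancel: cancel, name: fmt.Sprintf(format, args...), parent: t}
}

// Finish cancels the task's context with the given reason and
// deregisters it; idempotent.
func (t *taskImpl) Finish(reason string) {
	t.once.Do(func() {
		t.cancel(errors.New(reason))
		if t.parent != nil {
			t.parent.subWg.Done()
			globalWg.Done()
		}
	})
}

// FinishCause reports why the task was finished or canceled.
func (t *taskImpl) FinishCause() error { return context.Cause(t.ctx) }

// Wait blocks until all subtasks have finished.
func (t *taskImpl) Wait() { t.subWg.Wait() }

// OnComplete runs fn once the task's context is done.
func (t *taskImpl) OnComplete(about string, fn func()) {
	go func() {
		<-t.ctx.Done()
		fn()
	}()
}

// CancelGlobalContext cancels every task derived from the global context.
func CancelGlobalContext() { globalCancel() }

// GlobalContextWait waits for all registered subtasks, up to timeout.
func GlobalContextWait(timeout time.Duration) {
	done := make(chan struct{})
	go func() { globalWg.Wait(); close(done) }()
	select {
	case <-done:
	case <-time.After(timeout):
	}
}
```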

View file

@@ -2,51 +2,66 @@ package config
import (
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
"gopkg.in/yaml.v3"
)
type Config struct {
value *types.Config
proxyProviders F.Map[string, *PR.Provider]
providers F.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider
l logrus.FieldLogger
watcher W.Watcher
reloadReq chan struct{}
task task.Task
}
var instance *Config
var (
instance *Config
cfgWatcher watcher.Watcher
logger = logrus.WithField("module", "config")
reloadMu sync.Mutex
)
const configEventFlushInterval = 500 * time.Millisecond
const (
cfgRenameWarn = `Config file renamed, not reloading.
Make sure you rename it back before the next startup.`
cfgDeleteWarn = `Config file deleted, not reloading.
You may run "ls-config" to show or dump the current config.`
)
func GetInstance() *Config {
return instance
}
func Load() E.NestedError {
func newConfig() *Config {
return &Config{
value: types.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](),
task: task.GlobalTask("config"),
}
}
func Load() (*Config, E.NestedError) {
if instance != nil {
return nil
return instance, nil
}
instance = &Config{
value: types.DefaultConfig(),
proxyProviders: F.NewMapOf[string, *PR.Provider](),
l: logrus.WithField("module", "config"),
watcher: W.NewConfigFileWatcher(common.ConfigFileName),
reloadReq: make(chan struct{}, 1),
}
return instance.load()
instance = newConfig()
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
return instance, instance.load()
}
func Validate(data []byte) E.NestedError {
@@ -54,87 +69,90 @@ func Validate(data []byte) E.NestedError {
}
func MatchDomains() []string {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return instance.value.MatchDomains
}
func (cfg *Config) Value() types.Config {
if cfg == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
}
return *cfg.value
func WatchChanges() {
task := task.GlobalTask("Config watcher")
eventQueue := events.NewEventQueue(
task,
configEventFlushInterval,
OnConfigChange,
func(err E.NestedError) {
logger.Error(err)
},
)
eventQueue.Start(cfgWatcher.Events(task.Context()))
}
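
events.NewEventQueue and EventQueue.Start are not part of this diff. The sketch below illustrates, with assumed names and simplified signatures (the real API also takes a task.Task and passes a flush task to the callback), the debouncing they appear to provide: events arriving within the flush interval are batched so that a burst of file changes triggers a single reload.

```go
package events

import "time"

type (
	Action int
	Event  struct{ Action Action }

	// EventQueue batches watcher events and flushes them at most once
	// per interval, so a burst of changes triggers a single reload.
	EventQueue struct {
		flushInterval time.Duration
		onFlush       func(batch []Event)
		onError       func(err error)
	}
)

func NewEventQueue(flushInterval time.Duration, onFlush func([]Event), onError func(error)) *EventQueue {
	return &EventQueue{flushInterval: flushInterval, onFlush: onFlush, onError: onError}
}

// Start consumes the watcher's event and error channels until the
// event channel is closed.
func (q *EventQueue) Start(eventCh <-chan Event, errCh <-chan error) {
	go func() {
		var batch []Event
		ticker := time.NewTicker(q.flushInterval)
		defer ticker.Stop()
		for {
			select {
			case ev, ok := <-eventCh:
				if !ok {
					return
				}
				batch = append(batch, ev)
			case err := <-errCh:
				if err != nil {
					q.onError(err)
				}
			case <-ticker.C:
				if len(batch) > 0 {
					q.onFlush(batch)
					batch = nil
				}
			}
		}
	}()
}
```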
func (cfg *Config) GetAutoCertProvider() *autocert.Provider {
if instance == nil {
logrus.Panic("config has not been loaded, please check if there is any errors")
func OnConfigChange(flushTask task.Task, ev []events.Event) {
defer flushTask.Finish("config reload complete")
// no matter how many events occurred during the interval,
// reload just once, based on the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
logger.Warn(cfgRenameWarn)
return
case events.ActionFileDeleted:
logger.Warn(cfgDeleteWarn)
return
}
if err := Reload(); err != nil {
logger.Error(err)
}
return cfg.autocertProvider
}
func (cfg *Config) Reload() (err E.NestedError) {
cfg.stopProviders()
err = cfg.load()
cfg.StartProxyProviders()
return
func Reload() E.NestedError {
// avoid race between config change and API reload request
reloadMu.Lock()
defer reloadMu.Unlock()
newCfg := newConfig()
err := newCfg.load()
if err != nil {
return err
}
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
instance.task.Finish("config changed")
instance.task.Wait()
*instance = *newCfg
instance.StartProxyProviders()
return nil
}
func Value() types.Config {
return *instance.value
}
func GetAutoCertProvider() *autocert.Provider {
return instance.autocertProvider
}
func (cfg *Config) Task() task.Task {
return cfg.task
}
func (cfg *Config) StartProxyProviders() {
cfg.controlProviders("start", (*PR.Provider).StartAllRoutes)
}
func (cfg *Config) WatchChanges() {
task := common.NewTask("Config watcher")
go func() {
defer task.Finished()
for {
select {
case <-task.Context().Done():
return
case <-cfg.reloadReq:
if err := cfg.Reload(); err != nil {
cfg.l.Error(err)
}
}
}
}()
go func() {
eventCh, errCh := cfg.watcher.Events(task.Context())
for {
select {
case <-task.Context().Done():
return
case event := <-eventCh:
if event.Action == events.ActionFileDeleted || event.Action == events.ActionFileRenamed {
cfg.l.Error("config file deleted or renamed, ignoring...")
continue
} else {
cfg.reloadReq <- struct{}{}
}
case err := <-errCh:
cfg.l.Error(err)
continue
}
}
}()
}
func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) {
cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) {
p.RangeRoutes(func(a string, r *R.Route) {
do(a, r, p)
})
b := E.NewBuilder("errors starting providers")
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName())))
})
if b.HasError() {
logger.Error(b.Build())
}
}
func (cfg *Config) load() (res E.NestedError) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
cfg.l.Debug("loading config")
defer cfg.l.Debug("loaded config")
logger.Debug("loading config")
defer logger.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err != nil {
@@ -160,7 +178,7 @@ func (cfg *Config) load() (res E.NestedError) {
b.Add(cfg.loadProviders(&model.Providers))
cfg.value = model
R.SetFindMuxDomains(model.MatchDomains)
route.SetFindMuxDomains(model.MatchDomains)
return
}
@@ -169,8 +187,8 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return
}
cfg.l.Debug("initializing autocert")
defer cfg.l.Debug("initialized autocert")
logger.Debug("initializing autocert")
defer logger.Debug("initialized autocert")
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err != nil {
@@ -179,48 +197,34 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return
}
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (res E.NestedError) {
cfg.l.Debug("loading providers")
defer cfg.l.Debug("loaded providers")
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) {
subtask := cfg.task.Subtask("load providers")
defer subtask.Finish("done")
b := E.NewBuilder("errors loading providers")
defer b.To(&res)
errs := E.NewBuilder("errors loading providers")
results := E.NewBuilder("loaded providers")
defer errs.To(&outErr)
for _, filename := range providers.Files {
p, err := PR.NewFileProvider(filename)
p, err := proxy.NewFileProvider(filename)
if err != nil {
b.Add(err.Subject(filename))
errs.Add(err)
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(filename))
cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(filename))
results.Addf("%d routes from %s", p.NumRoutes(), filename)
}
for name, dockerHost := range providers.Docker {
p, err := PR.NewDockerProvider(name, dockerHost)
p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil {
b.Add(err.Subjectf("%s (%s)", name, dockerHost))
errs.Add(err.Subjectf("%s (%s)", name, dockerHost))
continue
}
cfg.proxyProviders.Store(p.GetName(), p)
b.Add(p.LoadRoutes().Subject(p.GetName()))
cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(p.GetName()))
results.Addf("%d routes from %s", p.NumRoutes(), name)
}
logger.Info(results.Build())
return
}
func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) {
errors := E.NewBuilder("errors in %s these providers", action)
cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) {
if err := do(p); err != nil {
errors.Add(err.Subject(p))
}
})
if err := errors.Build(); err != nil {
cfg.l.Error(err)
}
}
func (cfg *Config) stopProviders() {
cfg.controlProviders("stop routes", (*PR.Provider).StopAllRoutes)
}

View file

@@ -6,33 +6,35 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/homepage"
PR "github.com/yusing/go-proxy/internal/proxy/provider"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
func (cfg *Config) DumpEntries() map[string]*types.RawEntry {
entries := make(map[string]*types.RawEntry)
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
entries[alias] = r.Entry
func DumpEntries() map[string]*entry.RawEntry {
entries := make(map[string]*entry.RawEntry)
instance.providers.RangeAll(func(_ string, p *proxy.Provider) {
p.RangeRoutes(func(alias string, r *route.Route) {
entries[alias] = r.Entry
})
})
return entries
}
func (cfg *Config) DumpProviders() map[string]*PR.Provider {
entries := make(map[string]*PR.Provider)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
func DumpProviders() map[string]*proxy.Provider {
entries := make(map[string]*proxy.Provider)
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
entries[name] = p
})
return entries
}
func (cfg *Config) HomepageConfig() homepage.Config {
func HomepageConfig() homepage.Config {
var proto, port string
domains := cfg.value.MatchDomains
cert, _ := cfg.autocertProvider.GetCert(nil)
domains := instance.value.MatchDomains
cert, _ := instance.autocertProvider.GetCert(nil)
if cert != nil {
proto = "https"
port = common.ProxyHTTPSPort
@@ -42,9 +44,9 @@ func (cfg *Config) HomepageConfig() homepage.Config {
}
hpCfg := homepage.NewHomePageConfig()
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
entry := r.Raw
item := entry.Homepage
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
en := r.Raw
item := en.Homepage
if item == nil {
item = new(homepage.Item)
item.Show = true
@@ -63,12 +65,12 @@ func (cfg *Config) HomepageConfig() homepage.Config {
)
}
if r.IsDocker() {
if entry.IsDocker(r) {
if item.Category == "" {
item.Category = "Docker"
}
item.SourceType = string(PR.ProviderTypeDocker)
} else if r.UseLoadBalance() {
item.SourceType = string(proxy.ProviderTypeDocker)
} else if entry.UseLoadBalance(r) {
if item.Category == "" {
item.Category = "Load-balanced"
}
@@ -77,7 +79,7 @@ func (cfg *Config) HomepageConfig() homepage.Config {
if item.Category == "" {
item.Category = "Others"
}
item.SourceType = string(PR.ProviderTypeFile)
item.SourceType = string(proxy.ProviderTypeFile)
}
if item.URL == "" {
@@ -85,26 +87,26 @@ func (cfg *Config) HomepageConfig() homepage.Config {
item.URL = fmt.Sprintf("%s://%s.%s:%s", proto, strings.ToLower(alias), domains[0], port)
}
}
item.AltURL = r.URL().String()
item.AltURL = r.TargetURL().String()
hpCfg.Add(item)
})
return hpCfg
}
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
func RoutesByAlias(typeFilter ...route.RouteType) map[string]any {
routes := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream}
typeFilter = []route.RouteType{route.RouteTypeReverseProxy, route.RouteTypeStream}
}
for _, t := range typeFilter {
switch t {
case R.RouteTypeReverseProxy:
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
case route.RouteTypeReverseProxy:
route.GetReverseProxies().RangeAll(func(alias string, r *route.HTTPRoute) {
routes[alias] = r
})
case R.RouteTypeStream:
R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) {
case route.RouteTypeStream:
route.GetStreamProxies().RangeAll(func(alias string, r *route.StreamRoute) {
routes[alias] = r
})
}
@@ -112,12 +114,12 @@ func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
return routes
}
func (cfg *Config) Statistics() map[string]any {
func Statistics() map[string]any {
nTotalStreams := 0
nTotalRPs := 0
providerStats := make(map[string]PR.ProviderStats)
providerStats := make(map[string]proxy.ProviderStats)
cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) {
instance.providers.RangeAll(func(name string, p *proxy.Provider) {
providerStats[name] = p.Statistics()
})
@@ -133,9 +135,9 @@ func (cfg *Config) Statistics() map[string]any {
}
}
func (cfg *Config) FindRoute(alias string) *R.Route {
return F.MapFind(cfg.proxyProviders,
func(p *PR.Provider) (*R.Route, bool) {
func FindRoute(alias string) *route.Route {
return F.MapFind(instance.providers,
func(p *proxy.Provider) (*route.Route, bool) {
if route, ok := p.GetRoute(alias); ok {
return route, true
}

View file

@@ -0,0 +1,24 @@
package types
type (
Config struct {
Providers ProxyProviders `json:"providers" yaml:",flow"`
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
}
ProxyProviders struct {
Files []string `json:"include" yaml:"include"` // docker, file
Docker map[string]string `json:"docker" yaml:"docker"`
}
)
func DefaultConfig() *Config {
return &Config{
Providers: ProxyProviders{},
TimeoutShutdown: 3,
RedirectToHTTPS: false,
}
}

View file

@@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -36,22 +37,13 @@ var (
)
func init() {
go func() {
task := common.NewTask("close all docker client")
defer task.Finished()
for {
select {
case <-task.Context().Done():
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
}
})
clientMap.Clear()
return
task.GlobalTask("close docker clients").OnComplete("", func() {
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
}
}
}()
})
})
}
func (c *SharedClient) Connected() bool {
@@ -141,19 +133,10 @@ func ConnectClient(host string) (Client, E.NestedError) {
<-c.refCount.Zero()
clientMap.Delete(c.key)
if c.Client != nil {
if c.Connected() {
c.Client.Close()
c.Client = nil
c.l.Debugf("client closed")
}
}()
return c, nil
}
func CloseAllClients() {
clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()
logger.Debug("closed all clients")
}

View file

@@ -2,6 +2,7 @@ package docker
import (
"context"
"errors"
"time"
"github.com/docker/docker/api/types"
@@ -16,10 +17,13 @@ type ClientInfo struct {
}
var listOptions = container.ListOptions{
// created|restarting|running|removing|paused|exited|dead
// Filters: filters.NewArgs(
// filters.Arg("health", "healthy"),
// filters.Arg("health", "none"),
// filters.Arg("health", "starting"),
// filters.Arg("status", "created"),
// filters.Arg("status", "restarting"),
// filters.Arg("status", "running"),
// filters.Arg("status", "paused"),
// filters.Arg("status", "exited"),
// ),
All: true,
}
@@ -31,7 +35,7 @@ func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedE
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout"))
defer cancel()
var containers []types.Container

View file

@@ -32,11 +32,11 @@ type (
IsExcluded bool `json:"is_excluded" yaml:"-"`
IsExplicit bool `json:"is_explicit" yaml:"-"`
IsDatabase bool `json:"is_database" yaml:"-"`
IdleTimeout string `json:"idle_timeout" yaml:"-"`
WakeTimeout string `json:"wake_timeout" yaml:"-"`
StopMethod string `json:"stop_method" yaml:"-"`
StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only
StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only
IdleTimeout string `json:"idle_timeout,omitempty" yaml:"-"`
WakeTimeout string `json:"wake_timeout,omitempty" yaml:"-"`
StopMethod string `json:"stop_method,omitempty" yaml:"-"`
StopTimeout string `json:"stop_timeout,omitempty" yaml:"-"` // stop_method = "stop" only
StopSignal string `json:"stop_signal,omitempty" yaml:"-"` // stop_method = "stop" | "kill" only
Running bool `json:"running" yaml:"-"`
}
)
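
The newly added `omitempty` tags above keep unset idlewatcher fields out of marshaled JSON (for example in debug-ls-entries output); a standalone toy illustration, with types not from this repo:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type withTag struct {
	IdleTimeout string `json:"idle_timeout,omitempty"`
}

type withoutTag struct {
	IdleTimeout string `json:"idle_timeout"`
}

func main() {
	a, _ := json.Marshal(withTag{})
	b, _ := json.Marshal(withoutTag{})
	fmt.Println(string(a)) // {}
	fmt.Println(string(b)) // {"idle_timeout":""}
}
```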

View file

@@ -0,0 +1,112 @@
package idlewatcher
import (
"time"
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
type (
Config struct {
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"` // docker api takes integer seconds for timeout argument
StopMethod StopMethod `json:"stop_method,omitempty"`
StopSignal Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StopMethod string
Signal string
)
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
if cont == nil {
return nil, nil
}
if cont.IdleTimeout == "" {
return &Config{
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
b := E.NewBuilder("invalid idlewatcher config")
defer b.To(&res)
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout)
b.Add(err.Subjectf("%s", "idle_timeout"))
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout)
b.Add(err.Subjectf("%s", "wake_timeout"))
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
b.Add(err.Subjectf("%s", "stop_timeout"))
stopMethod, err := validateStopMethod(cont.StopMethod)
b.Add(err)
signal, err := validateSignal(cont.StopSignal)
b.Add(err)
if err := b.Build(); err != nil {
return
}
return &Config{
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopTimeout: int(stopTimeout.Seconds()),
StopMethod: stopMethod,
StopSignal: signal,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}, nil
}
func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value).With(err)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}
func validateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}
func validateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}
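
A hypothetical usage sketch for the validator above, shown as a fragment; the field values would normally be parsed from container labels by the docker provider, and the imports are this repo's packages (internal/docker, and this config package as idlewatcher):

```go
// Hypothetical wiring; the container values are illustrative.
func exampleIdlewatcherConfig() (*idlewatcher.Config, E.NestedError) {
	cont := &docker.Container{
		DockerHost:    "unix:///var/run/docker.sock",
		ContainerName: "whoami",
		ContainerID:   "abc123",
		IdleTimeout:   "10m",
		WakeTimeout:   "30s",
		StopTimeout:   "10s",
		StopMethod:    "stop",
		Running:       true,
	}
	// on success: IdleTimeout is 10*time.Minute, StopTimeout is 10
	// (integer seconds, as the Docker API expects)
	return idlewatcher.ValidateConfig(cont)
}
```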

View file

@@ -20,16 +20,15 @@ var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(lo
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
func (w *Watcher) makeRespBody(format string, args ...any) []byte {
msg := fmt.Sprintf(format, args...)
func (w *Watcher) makeLoadingPageBody() []byte {
msg := fmt.Sprintf("%s is starting...", w.ContainerName)
data := new(templateData)
data.CheckRedirectHeader = headerCheckRedirect
data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, "\n", "<br>")
data.Message = strings.ReplaceAll(data.Message, " ", "&ensp;")
data.Message = strings.ReplaceAll(msg, " ", "&ensp;")
buf := bytes.NewBuffer(make([]byte, 128)) // more than enough
buf := bytes.NewBuffer(make([]byte, 0, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect))) // zero length, preallocated capacity; a non-zero length would prepend stray NUL bytes
err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen in production
panic(err)

View file

@@ -1,197 +1,133 @@
package idlewatcher
import (
"context"
"net/http"
"strconv"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker struct {
*Watcher
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}
type waker struct {
_ U.NoCopy
client *http.Client
rp *gphttp.ReverseProxy
stream net.Stream
hc health.HealthChecker
ready atomic.Bool
}
func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker {
return &Waker{
Watcher: w,
client: &http.Client{
Timeout: 1 * time.Second,
Transport: rp.Transport,
},
rp: rp,
const (
idleWakerCheckInterval = 100 * time.Millisecond
idleWakerCheckTimeout = time.Second
)
// TODO: support stream
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) {
hcCfg := entry.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout
waker := &waker{
rp: rp,
stream: stream,
}
}
func (w *Waker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
shouldNext := w.wake(rw, r)
if !shouldNext {
return
watcher, err := registerWatcher(providerSubTask, entry, waker)
if err != nil {
return nil, err
}
w.rp.ServeHTTP(rw, r)
if rp != nil {
waker.hc = health.NewHTTPHealthChecker(entry.TargetURL(), hcCfg, rp.Transport)
} else if stream != nil {
waker.hc = health.NewRawHealthChecker(entry.TargetURL(), hcCfg)
} else {
panic("both nil")
}
return watcher, nil
}
/* HealthMonitor interface */
func (w *Waker) Start() {}
func (w *Waker) Stop() {
w.Unregister()
// lifetime should follow route provider
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) {
return newWaker(providerSubTask, entry, rp, nil)
}
func (w *Waker) UpdateConfig(config health.HealthCheckConfig) {
panic("use idlewatcher.Register instead")
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) {
return newWaker(providerSubTask, entry, nil, stream)
}
func (w *Waker) Name() string {
// Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError {
w.task.OnComplete("stop route", func() {
routeSubTask.Parent().Finish("watcher stopped")
})
return nil
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason string) {}
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
return w.String()
}
func (w *Waker) String() string {
return string(w.Alias)
// String implements health.HealthMonitor.
func (w *Watcher) String() string {
return w.ContainerName
}
func (w *Waker) Status() health.Status {
if w.ready.Load() {
return health.StatusHealthy
}
if !w.ContainerRunning {
return health.StatusNapping
}
return health.StatusStarting
}
func (w *Waker) Uptime() time.Duration {
// Uptime implements health.HealthMonitor.
func (w *Watcher) Uptime() time.Duration {
return 0
}
func (w *Waker) MarshalJSON() ([]byte, error) {
var url types.URL
if w.URL.String() != "http://:0" {
url = w.URL
// Status implements health.HealthMonitor.
func (w *Watcher) Status() health.Status {
if !w.ContainerRunning {
return health.StatusNapping
}
if w.ready.Load() {
return health.StatusHealthy
}
healthy, _, err := w.hc.CheckHealth()
switch {
case err != nil:
return health.StatusError
case healthy:
w.ready.Store(true)
return health.StatusHealthy
default:
return health.StatusStarting
}
}
// MarshalJSON implements health.HealthMonitor.
func (w *Watcher) MarshalJSON() ([]byte, error) {
var url net.URL
if w.hc.URL().Port() != "0" {
url = w.hc.URL()
}
return (&health.JSONRepresentation{
Name: w.Name(),
Status: w.Status(),
Config: &health.HealthCheckConfig{
Interval: w.IdleTimeout,
Timeout: w.WakeTimeout,
},
URL: url,
Config: w.hc.Config(),
URL: url,
}).MarshalJSON()
}
/* End of HealthMonitor interface */
func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is ready
if w.ready.Load() {
return true
}
ctx, cancel := context.WithTimeout(r.Context(), w.WakeTimeout)
defer cancel()
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeRespBody("%s waking up...", w.ContainerName)
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
}
return
}
select {
case <-w.task.Context().Done():
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
w.l.Debug("wake signal received")
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return
}
// maybe another request came in while we were waiting for the wake
if w.ready.Load() {
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
return
}
return true
}
for {
select {
case <-w.task.Context().Done():
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
wakeReq, err := http.NewRequestWithContext(
ctx,
http.MethodHead,
w.URL.String(),
nil,
)
if err != nil {
w.l.Errorf("new request err to %s: %s", r.URL, err)
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
wakeResp, err := w.client.Do(wakeReq)
if err == nil && wakeResp.StatusCode != http.StatusServiceUnavailable {
w.ready.Store(true)
w.l.Debug("awaken")
if isCheckRedirect {
rw.WriteHeader(http.StatusOK)
return
}
logrus.Infof("container %s is ready, passing through to %s", w.Alias, w.rp.TargetURL)
return true
}
// retry until the container is ready or timeout
time.Sleep(100 * time.Millisecond)
}
}
// static HealthMonitor interface check
func (w *Waker) _() health.HealthMonitor {
return w
}

View file

@@ -0,0 +1,105 @@
package idlewatcher
import (
"context"
"errors"
"net/http"
"strconv"
"time"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// ServeHTTP implements http.Handler
func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
shouldNext := w.wakeFromHTTP(rw, r)
if !shouldNext {
return
}
w.rp.ServeHTTP(rw, r)
}
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is already ready
if w.ready.Load() {
return true
}
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeLoadingPageBody()
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Content-Length", strconv.Itoa(len(body)))
rw.Header().Add("Cache-Control", "no-cache")
rw.Header().Add("Cache-Control", "no-store")
rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err)
}
return
}
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
checkCancelled := func() bool {
select {
case <-w.task.Context().Done():
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
case <-ctx.Done():
w.l.Debugf("wake cancelled: %s", context.Cause(ctx))
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return true
default:
return false
}
}
if checkCancelled() {
return false
}
w.l.Debug("wake signal received")
err := w.wakeIfStopped()
if err != nil {
w.l.Error(E.FailWith("wake", err))
http.Error(rw, "Error waking container", http.StatusInternalServerError)
return
}
for {
if checkCancelled() {
return false
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting...", w.String())
rw.WriteHeader(http.StatusOK)
return
}
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
return true
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View file

@@ -0,0 +1,87 @@
package idlewatcher
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Setup implements types.Stream.
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
conn, err = w.stream.Accept()
// timeout means no connection is accepted
var nErr *net.OpError
ok := errors.As(err, &nErr)
if ok && nErr.Timeout() {
return
}
if err != nil {
return
}
// a connection was accepted: wake the container before handing it off,
// rather than calling Accept again (which would drop the accepted connection)
if err = w.wakeFromStream(); err != nil {
return nil, err
}
return conn, nil
}
// CloseListeners implements types.Stream.
func (w *Watcher) CloseListeners() {
w.stream.CloseListeners()
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
return w.stream.Handle(conn)
}
func (w *Watcher) wakeFromStream() error {
// pass through if container is already ready
if w.ready.Load() {
return nil
}
w.l.Debug("wake signal received")
wakeErr := w.wakeIfStopped()
if wakeErr != nil {
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr)
w.l.Error(wakeErr)
return wakeErr
}
ctx, cancel := context.WithTimeoutCause(w.task.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
for {
select {
case <-w.task.Context().Done():
cause := w.task.FinishCause()
w.l.Debugf("wake cancelled: %s", cause)
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.l.Debugf("wake cancelled: %s", cause)
return cause
default:
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL())
return nil
}
// retry until the container is ready or timeout
time.Sleep(idleWakerCheckInterval)
}
}

View file

@@ -2,191 +2,193 @@ package idlewatcher
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type (
Watcher struct {
*P.ReverseProxyEntry
_ U.NoCopy
client D.Client
*idlewatcher.Config
*waker
ready atomic.Bool // whether the site is ready to accept connection
client D.Client
stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
ticker *time.Ticker
task common.Task
cancel context.CancelFunc
refCount *U.RefCount
l logrus.FieldLogger
ticker *time.Ticker
task task.Task
l *logrus.Entry
}
WakeDone <-chan error
WakeFunc func() WakeDone
StopCallback func() E.NestedError
StopCallback func() error
)
var (
watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex
portHistoryMap = F.NewMapOf[PT.Alias, string]()
logger = logrus.WithField("module", "idle_watcher")
)
func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
const dockerReqTimeout = 3 * time.Second
if entry.IdleTimeout == 0 {
return nil, failure.With(E.Invalid("idle_timeout", 0))
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) {
failure := E.Failure("idle_watcher register")
cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 {
panic("should not reach here")
}
watcherMapMu.Lock()
defer watcherMapMu.Unlock()
key := entry.ContainerID
if entry.URL.Port() != "0" {
portHistoryMap.Store(entry.Alias, entry.URL.Port())
}
key := cfg.ContainerID
if w, ok := watcherMap.Load(key); ok {
w.refCount.Add()
w.ReverseProxyEntry = entry
w.Config = cfg
w.waker = waker
w.resetIdleTimer()
return w, nil
}
client, err := D.ConnectClient(entry.DockerHost)
client, err := D.ConnectClient(cfg.DockerHost)
if err.HasError() {
return nil, failure.With(err)
}
w := &Watcher{
ReverseProxyEntry: entry,
client: client,
refCount: U.NewRefCounter(),
ticker: time.NewTicker(entry.IdleTimeout),
l: logger.WithField("container", entry.ContainerName),
Config: cfg,
waker: waker,
client: client,
task: providerSubtask,
ticker: time.NewTicker(cfg.IdleTimeout),
l: logger.WithField("container", cfg.ContainerName),
}
w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias)
w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w)
go w.watchUntilCancel()
go func() {
cause := w.watchUntilDestroy()
watcherMapMu.Lock()
watcherMap.Delete(w.ContainerID)
watcherMapMu.Unlock()
w.ticker.Stop()
w.client.Close()
w.task.Finish(cause.Error())
}()
return w, nil
}
func (w *Watcher) Unregister() {
w.refCount.Sub()
}
func (w *Watcher) containerStop() error {
return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{
func (w *Watcher) containerStop(ctx context.Context) error {
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout,
})
}
func (w *Watcher) containerPause() error {
return w.client.ContainerPause(w.task.Context(), w.ContainerID)
func (w *Watcher) containerPause(ctx context.Context) error {
return w.client.ContainerPause(ctx, w.ContainerID)
}
func (w *Watcher) containerKill() error {
return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal))
func (w *Watcher) containerKill(ctx context.Context) error {
return w.client.ContainerKill(ctx, w.ContainerID, string(w.StopSignal))
}
func (w *Watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.task.Context(), w.ContainerID)
func (w *Watcher) containerUnpause(ctx context.Context) error {
return w.client.ContainerUnpause(ctx, w.ContainerID)
}
func (w *Watcher) containerStart() error {
return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{})
func (w *Watcher) containerStart(ctx context.Context) error {
return w.client.ContainerStart(ctx, w.ContainerID, container.StartOptions{})
}
func (w *Watcher) containerStatus() (string, E.NestedError) {
func (w *Watcher) containerStatus() (string, error) {
if !w.client.Connected() {
return "", E.Failure("docker client closed")
return "", errors.New("docker client not connected")
}
json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID)
ctx, cancel := context.WithTimeoutCause(w.task.Context(), dockerReqTimeout, errors.New("docker request timeout"))
defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID)
if err != nil {
return "", E.FailWith("inspect container", err)
return "", fmt.Errorf("failed to inspect container: %w", err)
}
return json.State.Status, nil
}
func (w *Watcher) wakeIfStopped() E.NestedError {
if w.ready.Load() || w.ContainerRunning {
func (w *Watcher) wakeIfStopped() error {
if w.ContainerRunning {
return nil
}
status, err := w.containerStatus()
if err.HasError() {
if err != nil {
return err
}
// "created", "running", "paused", "restarting", "removing", "exited", or "dead"
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
defer cancel()
// !Hardcoded here since there are no constants from the Docker API
switch status {
case "exited", "dead":
return E.From(w.containerStart())
return w.containerStart(ctx)
case "paused":
return E.From(w.containerUnpause())
return w.containerUnpause(ctx)
case "running":
return nil
default:
return E.Unexpected("container state", status)
panic("should not reach here")
}
}
func (w *Watcher) getStopCallback() StopCallback {
var cb func() error
var cb func(context.Context) error
switch w.StopMethod {
case PT.StopMethodPause:
case idlewatcher.StopMethodPause:
cb = w.containerPause
case PT.StopMethodStop:
case idlewatcher.StopMethodStop:
cb = w.containerStop
case PT.StopMethodKill:
case idlewatcher.StopMethodKill:
cb = w.containerKill
default:
panic("should not reach here")
}
return func() E.NestedError {
status, err := w.containerStatus()
if err.HasError() {
return err
}
if status != "running" {
return nil
}
return E.From(cb())
return func() error {
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
defer cancel()
return cb(ctx)
}
}
func (w *Watcher) resetIdleTimer() {
w.l.Trace("reset idle timer")
w.ticker.Reset(w.IdleTimeout)
}
func (w *Watcher) watchUntilCancel() {
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) {
eventTask = w.task.Subtask("watcher for %s", w.ContainerID)
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID),
@ -194,34 +196,47 @@ func (w *Watcher) watchUntilCancel() {
W.DockerFilterStop,
W.DockerFilterDie,
W.DockerFilterKill,
W.DockerFilterDestroy,
W.DockerFilterPause,
W.DockerFilterUnpause,
),
})
return
}
defer func() {
w.cancel()
w.ticker.Stop()
w.client.Close()
watcherMap.Delete(w.ContainerID)
w.task.Finished()
}()
// watchUntilDestroy waits for the container to be created, started, or unpaused,
// and then resets the idle timer.
//
// When the container is stopped, paused,
// or killed, the idle timer is stopped and the ContainerRunning flag is set to false.
//
// When the idle timer fires, the container is stopped according to the
// stop method.
//
// It exits only if the context is canceled, the container is destroyed,
// errors occur on the docker client, or the route provider dies (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() error {
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
for {
select {
case <-w.task.Context().Done():
w.l.Debug("stopped by context done")
return
case <-w.refCount.Zero():
w.l.Debug("stopped by zero ref count")
return
cause := context.Cause(w.task.Context())
w.l.Debugf("watcher stopped by context done: %s", cause)
return cause
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))
return
return err.Error()
}
case e := <-dockerEventCh:
switch {
case e.Action == events.ActionContainerDestroy:
w.ContainerRunning = false
w.ready.Store(false)
w.l.Info("watcher stopped by container destruction")
return errors.New("container destroyed")
// create / start / unpause
case e.Action.IsContainerWake():
w.ContainerRunning = true
@ -229,18 +244,31 @@ func (w *Watcher) watchUntilCancel() {
w.l.Info("container awaken")
case e.Action.IsContainerSleep(): // stop / pause / kill
w.ContainerRunning = false
w.ticker.Stop()
w.ready.Store(false)
w.ticker.Stop()
default:
w.l.Errorf("unexpected docker event: %s", e)
}
// keep the stored container name and id in sync; an id change also requires recreating the event stream
if w.ContainerName != e.ActorName {
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName)
w.ContainerName = e.ActorName
}
if w.ContainerID != e.ActorID {
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID)
w.ContainerID = e.ActorID
// recreate event stream
eventTask.Finish("recreate event stream")
eventTask, dockerEventCh, dockerEventErrCh = w.getEventCh(dockerWatcher)
}
case <-w.ticker.C:
w.l.Debug("idle timeout")
w.ticker.Stop()
if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
} else {
w.l.Info("stopped by idle timeout")
if w.ContainerRunning {
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err)
} else {
w.l.Info("container stopped by idle timeout")
}
}
}
}
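
The lifecycle described in the watchUntilDestroy comment reduces to a single select loop: container events drive the idle ticker, and the ticker drives the stop call. A compact sketch with plain channels standing in for the Docker event stream (all names here are illustrative, not the real watcher API):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

func watchLoop(ctx context.Context, events <-chan string, idleTimeout time.Duration, stop func() error) error {
    ticker := time.NewTicker(idleTimeout)
    defer ticker.Stop()
    for {
        select {
        case <-ctx.Done():
            return context.Cause(ctx)
        case e := <-events:
            switch e {
            case "destroy":
                return errors.New("container destroyed")
            case "start", "unpause": // container awake: restart the idle countdown
                ticker.Reset(idleTimeout)
            case "stop", "pause", "kill": // container asleep: no countdown needed
                ticker.Stop()
            }
        case <-ticker.C: // idle timeout fired
            ticker.Stop()
            if err := stop(); err != nil && !errors.Is(err, context.Canceled) {
                fmt.Println("stop failed:", err)
            }
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
    defer cancel()
    ch := make(chan string, 1)
    ch <- "start"
    fmt.Println(watchLoop(ctx, ch, 100*time.Millisecond, func() error {
        fmt.Println("container stopped by idle timeout")
        return nil
    }))
}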

View file

@ -2,6 +2,7 @@ package docker
import (
"context"
"errors"
"time"
E "github.com/yusing/go-proxy/internal/error"
@ -19,7 +20,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
}
func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel()
json, err := c.ContainerInspect(ctx, containerID)

View file

@ -17,35 +17,36 @@ type builder struct {
}
func NewBuilder(format string, args ...any) Builder {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
if len(args) > 0 {
return Builder{&builder{message: fmt.Sprintf(format, args...)}}
}
return Builder{&builder{message: format}}
}
// Adding nil is a no-op;
// you may safely pass expressions returning errors to it.
func (b Builder) Add(err NestedError) Builder {
func (b Builder) Add(err NestedError) {
if err != nil {
b.Lock()
b.errors = append(b.errors, err)
b.Unlock()
}
return b
}
func (b Builder) AddE(err error) Builder {
return b.Add(From(err))
func (b Builder) AddE(err error) {
b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) Builder {
return b.Add(errorf(format, args...))
func (b Builder) Addf(format string, args ...any) {
b.Add(errorf(format, args...))
}
func (b Builder) AddRangeE(errs ...error) Builder {
func (b Builder) AddRangeE(errs ...error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.AddE(err)
}
return b
}
// Build builds a NestedError based on the errors collected in the Builder.
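
The Builder change drops the fluent return values: since Builder is a value type wrapping a pointer, Add can mutate shared state directly and no longer needs to return the Builder for chaining. A self-contained sketch of that shape, with errors.Join standing in for the real NestedError aggregation:

package main

import (
    "errors"
    "fmt"
    "sync"
)

// builder mirrors the diff's layout: a value type (Builder) embedding a
// pointer, so method calls on copies still mutate the same error list.
type builder struct {
    mu   sync.Mutex
    errs []error
}

type Builder struct{ *builder }

func NewBuilder() Builder { return Builder{&builder{}} }

// Add is a no-op on nil, so callers can pass error-returning
// expressions without checking them first.
func (b Builder) Add(err error) {
    if err == nil {
        return
    }
    b.mu.Lock()
    b.errs = append(b.errs, err)
    b.mu.Unlock()
}

// Build aggregates everything collected so far (nil when empty).
func (b Builder) Build() error { return errors.Join(b.errs...) }

func main() {
    b := NewBuilder()
    b.Add(nil)                    // safe no-op
    b.Add(errors.New("bad host")) // collected
    b.Add(errors.New("bad port")) // collected
    fmt.Println(b.Build())
}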

View file

@ -2,6 +2,7 @@ package error
import (
stderrors "errors"
"fmt"
"reflect"
)
@ -16,6 +17,7 @@ var (
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
ErrPanicRecv = stderrors.New("panic")
)
const fmtSubjectWhat = "%w %v: %q"
@ -75,3 +77,7 @@ func TypeError2(subject any, from, to reflect.Value) NestedError {
func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}
func PanicRecv(format string, args ...any) NestedError {
return errorf("%w%s", ErrPanicRecv, fmt.Sprintf(format, args...))
}

View file

@ -4,18 +4,20 @@ import (
"hash/fnv"
"net"
"net/http"
"sync"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware
pool servers
mu sync.Mutex
}
func (lb *LoadBalancer) newIPHash() impl {
impl := &ipHash{LoadBalancer: lb}
impl := new(ipHash)
if len(lb.Options) == 0 {
return impl
}
@ -26,10 +28,37 @@ func (lb *LoadBalancer) newIPHash() impl {
}
return impl
}
func (ipHash) OnAddServer(srv *Server) {}
func (ipHash) OnRemoveServer(srv *Server) {}
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
func (impl *ipHash) OnAddServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
return
}
if s == nil {
impl.pool[i] = srv
return
}
}
impl.pool = append(impl.pool, srv)
}
func (impl *ipHash) OnRemoveServer(srv *Server) {
impl.mu.Lock()
defer impl.mu.Unlock()
for i, s := range impl.pool {
if s == srv {
impl.pool[i] = nil
return
}
}
}
func (impl *ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
if impl.realIP != nil {
impl.realIP.ModifyRequest(impl.serveHTTP, rw, r)
} else {
@ -37,7 +66,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request)
}
}
func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError)
@ -45,10 +74,12 @@ func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
return
}
idx := hashIP(ip) % uint32(len(impl.pool))
if impl.pool[idx].Status().Bad() {
srv := impl.pool[idx]
if srv == nil || srv.Status().Bad() {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
impl.pool[idx].ServeHTTP(rw, r)
srv.ServeHTTP(rw, r)
}
func hashIP(ip string) uint32 {

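The ipHash changes hinge on keeping slice indices stable: removed servers leave a nil hole instead of shifting the slice, so hash(ip) % len(pool) keeps mapping unaffected clients to the same backend. A self-contained sketch of that idea, using the same fnv hash family the file imports (the Server type and method names here are stand-ins):

package main

import (
    "fmt"
    "hash/fnv"
)

type Server struct{ Name string }

type ipHash struct{ pool []*Server }

func hashIP(ip string) uint32 {
    h := fnv.New32a()
    h.Write([]byte(ip))
    return h.Sum32()
}

func (i *ipHash) add(srv *Server) {
    for n, s := range i.pool {
        if s == nil { // reuse a hole left by a removed server
            i.pool[n] = srv
            return
        }
    }
    i.pool = append(i.pool, srv)
}

func (i *ipHash) remove(srv *Server) {
    for n, s := range i.pool {
        if s == srv {
            i.pool[n] = nil // tombstone: keeps all other indices stable
            return
        }
    }
}

func (i *ipHash) pick(ip string) *Server {
    if len(i.pool) == 0 {
        return nil
    }
    return i.pool[hashIP(ip)%uint32(len(i.pool))]
}

func main() {
    lb := &ipHash{}
    a, b := &Server{"a"}, &Server{"b"}
    lb.add(a)
    lb.add(b)
    fmt.Println(lb.pick("10.0.0.7")) // deterministic for a fixed pool size
    lb.remove(a)
    fmt.Println(lb.pick("10.0.0.7")) // may be nil now: the caller returns 503
}
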
View file

@ -5,8 +5,9 @@ import (
"sync"
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -28,7 +29,9 @@ type (
impl
*Config
pool servers
task task.Task
pool Pool
poolMu sync.Mutex
sumWeight weightType
@ -41,11 +44,35 @@ type (
const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)}
lb := &LoadBalancer{
Config: new(Config),
pool: newPool(),
task: task.DummyTask(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
lb.startTime = time.Now()
lb.task = routeSubtask
lb.task.OnComplete("loadbalancer cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v)
})
}
lb.pool.Clear()
})
return nil
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason string) {
lb.task.Finish(reason)
}
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case Unset, RoundRobin:
@ -57,9 +84,9 @@ func (lb *LoadBalancer) updateImpl() {
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
for _, srv := range lb.pool {
lb.pool.RangeAll(func(_ string, srv *Server) {
lb.impl.OnAddServer(srv)
}
})
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
@ -91,55 +118,60 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = append(lb.pool, srv)
if lb.pool.Has(srv.Name) {
old, _ := lb.pool.Load(srv.Name)
lb.sumWeight -= old.Weight
lb.impl.OnRemoveServer(old)
}
lb.pool.Store(srv.Name, srv)
lb.sumWeight += srv.Weight
lb.Rebalance()
lb.rebalance()
lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.sumWeight -= srv.Weight
lb.Rebalance()
lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool {
if s == srv {
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
break
}
}
if lb.IsEmpty() {
lb.pool = nil
if !lb.pool.Has(srv.Name) {
return
}
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
lb.pool.Delete(srv.Name)
lb.sumWeight -= srv.Weight
lb.rebalance()
lb.impl.OnRemoveServer(srv)
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
logger.Infof("[remove] loadbalancer %s stopped", lb.Link)
return
}
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
}
func (lb *LoadBalancer) IsEmpty() bool {
return len(lb.pool) == 0
}
func (lb *LoadBalancer) Rebalance() {
func (lb *LoadBalancer) rebalance() {
if lb.sumWeight == maxWeight {
return
}
if lb.pool.Size() == 0 {
return
}
if lb.sumWeight == 0 { // distribute evenly
weightEach := maxWeight / weightType(len(lb.pool))
remainder := maxWeight % weightType(len(lb.pool))
for _, s := range lb.pool {
weightEach := maxWeight / weightType(lb.pool.Size())
remainder := maxWeight % weightType(lb.pool.Size())
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightEach
lb.sumWeight += weightEach
if remainder > 0 {
s.Weight++
remainder--
}
}
})
return
}
@ -147,18 +179,18 @@ func (lb *LoadBalancer) Rebalance() {
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
lb.sumWeight = 0
for _, s := range lb.pool {
lb.pool.RangeAll(func(_ string, s *Server) {
s.Weight = weightType(float64(s.Weight) * scaleFactor)
lb.sumWeight += s.Weight
}
})
delta := maxWeight - lb.sumWeight
if delta == 0 {
return
}
for _, s := range lb.pool {
lb.pool.Range(func(_ string, s *Server) bool {
if delta == 0 {
break
return false
}
if delta > 0 {
s.Weight++
@ -169,7 +201,8 @@ func (lb *LoadBalancer) Rebalance() {
lb.sumWeight--
delta++
}
}
return true
})
}
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
@ -181,23 +214,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
lb.impl.ServeHTTP(srvs, rw, r)
}
func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
lb.startTime = time.Now()
logger.Debugf("loadbalancer %s started", lb.Link)
}
func (lb *LoadBalancer) Stop() {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.pool = nil
logger.Debugf("loadbalancer %s stopped", lb.Link)
}
func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
@ -205,9 +221,10 @@ func (lb *LoadBalancer) Uptime() time.Duration {
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
for _, v := range lb.pool {
lb.pool.RangeAll(func(k string, v *Server) {
extra[v.Name] = v.healthMon
}
})
return (&health.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
@ -227,7 +244,7 @@ func (lb *LoadBalancer) Name() string {
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if len(lb.pool) == 0 {
if lb.pool.Size() == 0 {
return health.StatusUnknown
}
if len(lb.availServers()) == 0 {
@ -241,21 +258,13 @@ func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
avail := make(servers, 0, len(lb.pool))
for _, s := range lb.pool {
if s.Status().Bad() {
continue
func (lb *LoadBalancer) availServers() []*Server {
avail := make([]*Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv *Server) {
if srv.Status().Bad() {
return
}
avail = append(avail, s)
}
avail = append(avail, srv)
})
return avail
}
// static HealthMonitor interface check
func (lb *LoadBalancer) _() health.HealthMonitor {
return lb
}
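
The rebalance logic normalizes all server weights to sum to maxWeight (100): if no weights were set it distributes evenly, otherwise it scales every weight and then spreads the integer-rounding remainder one point at a time. A trimmed, self-contained version using plain int slices instead of the repo's Pool type:

package main

import "fmt"

const maxWeight = 100

func rebalance(weights []int) {
    sum := 0
    for _, w := range weights {
        sum += w
    }
    if sum == 0 { // nothing configured: distribute evenly
        each, rem := maxWeight/len(weights), maxWeight%len(weights)
        for i := range weights {
            weights[i] = each
            if rem > 0 { // hand out the remainder one point at a time
                weights[i]++
                rem--
            }
        }
        return
    }
    scale := float64(maxWeight) / float64(sum)
    sum = 0
    for i, w := range weights {
        weights[i] = int(float64(w) * scale) // truncation may undershoot
        sum += weights[i]
    }
    // fix the rounding delta so the total is exactly maxWeight
    for i := 0; sum != maxWeight && i < len(weights); i++ {
        if sum < maxWeight {
            weights[i]++
            sum++
        } else {
            weights[i]--
            sum--
        }
    }
}

func main() {
    w := []int{30, 30, 20, 10} // sums to 90
    rebalance(w)
    fmt.Println(w) // [34 33 22 11] — sums to 100
}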

View file

@ -13,7 +13,7 @@ func TestRebalance(t *testing.T) {
for range 10 {
lb.AddServer(&Server{})
}
lb.Rebalance()
lb.rebalance()
ExpectEqual(t, lb.sumWeight, maxWeight)
})
t.Run("less", func(t *testing.T) {
@ -23,7 +23,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})
@ -36,7 +36,7 @@ func TestRebalance(t *testing.T) {
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
lb.Rebalance()
lb.rebalance()
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
ExpectEqual(t, lb.sumWeight, maxWeight)
})

View file

@ -6,6 +6,7 @@ import (
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@ -20,9 +21,12 @@ type (
handler http.Handler
healthMon health.HealthMonitor
}
servers []*Server
servers = []*Server
Pool = F.Map[string, *Server]
)
var newPool = F.NewMap[Pool]
func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server {
srv := &Server{
Name: name,

View file

@ -48,11 +48,11 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
continue
}
m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m)
}
if chainErr.HasError() {

View file

@ -0,0 +1,19 @@
package types
import (
"fmt"
"net"
)
type Stream interface {
fmt.Stringer
Setup() error
Accept() (conn StreamConn, err error)
Handle(conn StreamConn) error
CloseListeners()
}
type StreamConn interface {
RemoteAddr() net.Addr
Close() error
}
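
A hypothetical TCP implementation of the new Stream interface, to show the intended call sequence (Setup, then an Accept loop, Handle per connection, CloseListeners on shutdown). net.Conn already satisfies the StreamConn shape (RemoteAddr/Close); the address and echo behaviour below are purely illustrative:

package main

import (
    "fmt"
    "io"
    "net"
)

type tcpStream struct {
    addr string
    l    net.Listener
}

func (s *tcpStream) String() string { return "tcp://" + s.addr }

// Setup binds the listener; called once before the accept loop.
func (s *tcpStream) Setup() error {
    l, err := net.Listen("tcp", s.addr)
    if err != nil {
        return err
    }
    s.l = l
    return nil
}

func (s *tcpStream) Accept() (net.Conn, error) { return s.l.Accept() }

// Handle serves one accepted connection; a real proxy would dial the
// upstream here instead of echoing.
func (s *tcpStream) Handle(conn net.Conn) error {
    defer conn.Close()
    _, err := io.Copy(conn, conn)
    return err
}

func (s *tcpStream) CloseListeners() { s.l.Close() }

func main() {
    s := &tcpStream{addr: "127.0.0.1:0"}
    if err := s.Setup(); err != nil {
        panic(err)
    }
    defer s.CloseListeners()
    fmt.Println("listening as", s)
}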

View file

@ -1,177 +0,0 @@
package proxy
import (
"fmt"
"net/url"
"time"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
ReverseProxyEntry struct { // real model after validation
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.Scheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns T.PathPatterns `json:"path_patterns,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty"`
/* Docker only */
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopMethod T.StopMethod `json:"stop_method,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"`
StopSignal T.Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StreamEntry struct {
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.StreamScheme `json:"scheme,omitempty"`
Host T.Host `json:"host,omitempty"`
Port T.StreamPort `json:"port,omitempty"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
}
)
func (rp *ReverseProxyEntry) UseIdleWatcher() bool {
return rp.IdleTimeout > 0 && rp.IsDocker()
}
func (rp *ReverseProxyEntry) UseLoadBalance() bool {
return rp.LoadBalance != nil && rp.LoadBalance.Link != ""
}
func (rp *ReverseProxyEntry) IsDocker() bool {
return rp.DockerHost != ""
}
func (rp *ReverseProxyEntry) IsZeroPort() bool {
return rp.URL.Port() == "0"
}
func (rp *ReverseProxyEntry) ShouldNotServe() bool {
return rp.IsZeroPort() && !rp.UseIdleWatcher()
}
func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme)
if err != nil {
return nil, err
}
var entry any
e := E.NewBuilder("error validating entry")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err != nil {
return nil, err
}
return entry, nil
}
func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
cont := m.Container
if cont == nil {
cont = D.DummyContainer
}
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(cont.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(cont.StopSignal)
b.Add(err)
if err != nil {
return nil
}
return &ReverseProxyEntry{
Raw: m,
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
HealthCheck: &m.HealthCheck,
LoadBalance: &m.LoadBalance,
Middlewares: m.Middlewares,
IdleTimeout: idleTimeout,
WakeTimeout: wakeTimeout,
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}
}
func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
host, err := T.ValidateHost(m.Host)
b.Add(err)
port, err := T.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := T.ValidateStreamScheme(m.Scheme)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Raw: m,
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,
Port: port,
Healthcheck: &m.HealthCheck,
}
}

View file

@ -0,0 +1,68 @@
package entry
import (
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Entry interface {
TargetName() string
TargetURL() net.URL
RawEntry() *RawEntry
LoadBalanceConfig() *loadbalancer.Config
HealthCheckConfig() *health.HealthCheckConfig
IdlewatcherConfig() *idlewatcher.Config
}
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) {
m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme)
if err != nil {
return nil, err
}
var entry Entry
e := E.NewBuilder("error validating entry")
if scheme.IsStream() {
entry = validateStreamEntry(m, e)
} else {
entry = validateRPEntry(m, scheme, e)
}
if err := e.Build(); err != nil {
return nil, err
}
return entry, nil
}
func IsDocker(entry Entry) bool {
iw := entry.IdlewatcherConfig()
return iw != nil && iw.ContainerID != ""
}
func IsZeroPort(entry Entry) bool {
return entry.TargetURL().Port() == "0"
}
func ShouldNotServe(entry Entry) bool {
return IsZeroPort(entry) && !UseIdleWatcher(entry)
}
func UseLoadBalance(entry Entry) bool {
lb := entry.LoadBalanceConfig()
return lb != nil && lb.Link != ""
}
func UseIdleWatcher(entry Entry) bool {
iw := entry.IdlewatcherConfig()
return iw != nil && iw.IdleTimeout > 0
}
func UseHealthCheck(entry Entry) bool {
hc := entry.HealthCheckConfig()
return hc != nil && !hc.Disabled
}
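
ValidateEntry dispatches on the scheme: stream schemes produce a StreamEntry, everything else a ReverseProxyEntry, both behind the shared Entry interface. A minimal sketch of that dispatch with stand-in types (the real validation collects field errors through the Builder as shown above):

package main

import "fmt"

type Entry interface{ TargetName() string }

type streamEntry struct{ alias string }
type rpEntry struct{ alias string }

func (s *streamEntry) TargetName() string { return s.alias }
func (r *rpEntry) TargetName() string     { return r.alias }

func validate(alias, scheme string) (Entry, error) {
    switch scheme {
    case "tcp", "udp": // stream schemes
        return &streamEntry{alias}, nil
    case "http", "https":
        return &rpEntry{alias}, nil
    default:
        return nil, fmt.Errorf("invalid scheme %q", scheme)
    }
}

func main() {
    e, _ := validate("db", "tcp")
    fmt.Printf("%T %s\n", e, e.TargetName()) // *main.streamEntry db
}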

View file

@ -1,4 +1,4 @@
package types
package entry
import (
"strconv"
@ -21,16 +21,16 @@ type (
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme,omitempty" yaml:"scheme"`
Host string `json:"host,omitempty" yaml:"host"`
Port string `json:"port,omitempty" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme,omitempty" yaml:"scheme"`
Host string `json:"host,omitempty" yaml:"host"`
Port string `json:"port,omitempty" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"`
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"`
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */
Container *docker.Container `json:"container,omitempty" yaml:"-"`
@ -122,29 +122,41 @@ func (e *RawEntry) FillMissingFields() {
}
}
if e.HealthCheck.Interval == 0 {
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
if e.HealthCheck == nil {
e.HealthCheck = new(health.HealthCheckConfig)
}
if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
if e.HealthCheck.Disabled {
e.HealthCheck = nil
} else {
if e.HealthCheck.Interval == 0 {
e.HealthCheck.Interval = common.HealthCheckIntervalDefault
}
if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
}
}
if cont.IdleTimeout == "" {
cont.IdleTimeout = common.IdleTimeoutDefault
}
if cont.WakeTimeout == "" {
cont.WakeTimeout = common.WakeTimeoutDefault
}
if cont.StopTimeout == "" {
cont.StopTimeout = common.StopTimeoutDefault
}
if cont.StopMethod == "" {
cont.StopMethod = common.StopMethodDefault
if cont.IdleTimeout != "" {
if cont.WakeTimeout == "" {
cont.WakeTimeout = common.WakeTimeoutDefault
}
if cont.StopTimeout == "" {
cont.StopTimeout = common.StopTimeoutDefault
}
if cont.StopMethod == "" {
cont.StopMethod = common.StopMethodDefault
}
}
e.Port = joinPorts(lp, pp, extra)
if e.Port == "" || e.Host == "" {
e.Port = "0"
if lp != "" {
e.Port = lp + ":0"
} else {
e.Port = "0"
}
}
}
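
The tail of FillMissingFields now preserves the listening-port half when no target port can be resolved. A sketch of just that fallback rule, with joinPorts and the entry fields simplified away:

package main

import "fmt"

// fallbackPort mirrors the rule above: an unresolvable target keeps its
// listening port (if any) and gets proxy port 0, which downstream code
// treats as "do not serve unless the idlewatcher can wake it".
func fallbackPort(listening, resolved, host string) string {
    if resolved != "" && host != "" {
        return resolved // a usable target port was found
    }
    if listening != "" {
        return listening + ":0" // keep the listening side, mark the target unknown
    }
    return "0"
}

func main() {
    fmt.Println(fallbackPort("", "", ""))         // 0
    fmt.Println(fallbackPort("8080", "", "host")) // 8080:0
}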

View file

@ -0,0 +1,98 @@
package entry
import (
"fmt"
"net/url"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type ReverseProxyEntry struct { // real model after validation
Raw *RawEntry `json:"raw"`
Alias fields.Alias `json:"alias,omitempty"`
Scheme fields.Scheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns fields.PathPatterns `json:"path_patterns,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
Middlewares docker.NestedLabelMap `json:"middlewares,omitempty"`
/* Docker only */
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
}
func (rp *ReverseProxyEntry) TargetName() string {
return string(rp.Alias)
}
func (rp *ReverseProxyEntry) TargetURL() net.URL {
return rp.URL
}
func (rp *ReverseProxyEntry) RawEntry() *RawEntry {
return rp.Raw
}
func (rp *ReverseProxyEntry) LoadBalanceConfig() *loadbalancer.Config {
return rp.LoadBalance
}
func (rp *ReverseProxyEntry) HealthCheckConfig() *health.HealthCheckConfig {
return rp.HealthCheck
}
func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
return rp.Idlewatcher
}
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry {
cont := m.Container
if cont == nil {
cont = docker.DummyContainer
}
lb := m.LoadBalance
if lb != nil && lb.Link == "" {
lb = nil
}
host, err := fields.ValidateHost(m.Host)
b.Add(err)
port, err := fields.ValidatePort(m.Port)
b.Add(err)
pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if err != nil {
return nil
}
return &ReverseProxyEntry{
Raw: m,
Alias: fields.NewAlias(m.Alias),
Scheme: s,
URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns,
HealthCheck: m.HealthCheck,
LoadBalance: lb,
Middlewares: m.Middlewares,
Idlewatcher: idleWatcherCfg,
}
}

View file

@ -0,0 +1,89 @@
package entry
import (
"fmt"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type StreamEntry struct {
Raw *RawEntry `json:"raw"`
Alias fields.Alias `json:"alias,omitempty"`
Scheme fields.StreamScheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
Host fields.Host `json:"host,omitempty"`
Port fields.StreamPort `json:"port,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
/* Docker only */
Idlewatcher *idlewatcher.Config `json:"idlewatcher,omitempty"`
}
func (s *StreamEntry) TargetName() string {
return string(s.Alias)
}
func (s *StreamEntry) TargetURL() net.URL {
return s.URL
}
func (s *StreamEntry) RawEntry() *RawEntry {
return s.Raw
}
func (s *StreamEntry) LoadBalanceConfig() *loadbalancer.Config {
// TODO: support stream load balance
return nil
}
func (s *StreamEntry) HealthCheckConfig() *health.HealthCheckConfig {
return s.HealthCheck
}
func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
return s.Idlewatcher
}
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry {
cont := m.Container
if cont == nil {
cont = docker.DummyContainer
}
host, err := fields.ValidateHost(m.Host)
b.Add(err)
port, err := fields.ValidateStreamPort(m.Port)
b.Add(err)
scheme, err := fields.ValidateStreamScheme(m.Scheme)
b.Add(err)
url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if b.HasError() {
return nil
}
return &StreamEntry{
Raw: m,
Alias: fields.NewAlias(m.Alias),
Scheme: *scheme,
URL: url,
Host: host,
Port: port,
HealthCheck: m.HealthCheck,
Idlewatcher: idleWatcherCfg,
}
}

View file

@ -1,17 +0,0 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type Signal string
func ValidateSignal(s string) (Signal, E.NestedError) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
return Signal(s), nil
}
return "", E.Invalid("signal", s)
}

View file

@ -1,23 +0,0 @@
package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type StopMethod string
const (
StopMethodPause StopMethod = "pause"
StopMethodStop StopMethod = "stop"
StopMethodKill StopMethod = "kill"
)
func ValidateStopMethod(s string) (StopMethod, E.NestedError) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil
default:
return "", E.Invalid("stop_method", sm)
}
}

View file

@ -1,18 +0,0 @@
package fields
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value)
}
if d < 0 {
return 0, E.Invalid("duration", "negative value")
}
return d, nil
}

View file

@ -9,23 +9,21 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
"github.com/yusing/go-proxy/internal/proxy/entry"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/task"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type (
HTTPRoute struct {
*P.ReverseProxyEntry
*entry.ReverseProxyEntry
HealthMon health.HealthMonitor `json:"health,omitempty"`
@ -33,6 +31,8 @@ type (
server *loadbalancer.Server
handler http.Handler
rp *gphttp.ReverseProxy
task task.Task
}
SubdomainKey = PT.Alias
@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
}
}
func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
var trans *http.Transport
if entry.NoTLSVerify {
@ -84,12 +84,10 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) {
}
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
r := &HTTPRoute{
ReverseProxyEntry: entry,
rp: rp,
task: task.DummyTask(),
}
return r, nil
}
@ -98,39 +96,34 @@ func (r *HTTPRoute) String() string {
return string(r.Alias)
}
func (r *HTTPRoute) URL() url.URL {
return r.ReverseProxyEntry.URL
}
func (r *HTTPRoute) Start() E.NestedError {
if r.ShouldNotServe() {
// Start implements task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.handler != nil {
return nil
}
if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) {
if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = true
}
switch {
case r.UseIdleWatcher():
watcher, err := idlewatcher.Register(r.ReverseProxyEntry)
case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewHTTPWaker(wakerTask, r.ReverseProxyEntry, r.rp)
if err != nil {
return err
}
waker := idlewatcher.NewWaker(watcher, r.rp)
r.handler = waker
r.HealthMon = waker
case !r.HealthCheck.Disabled:
r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask(r.String()), r.URL(), r.HealthCheck)
case entry.UseHealthCheck(r):
r.HealthMon = health.NewHTTPHealthMonitor(r.TargetURL(), r.HealthCheck, r.rp.Transport)
}
r.task = providerSubtask
if r.handler == nil {
switch {
@ -146,44 +139,26 @@ func (r *HTTPRoute) Start() E.NestedError {
}
if r.HealthMon != nil {
r.HealthMon.Start()
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn(E.FailWith("health monitor", err))
}
}
if r.UseLoadBalance() {
if entry.UseLoadBalance(r) {
r.addToLoadBalancer()
} else {
httpRoutes.Store(string(r.Alias), r)
r.task.OnComplete("stop rp", func() {
httpRoutes.Delete(string(r.Alias))
})
}
return nil
}
func (r *HTTPRoute) Stop() (_ E.NestedError) {
if r.handler == nil {
return
}
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.loadBalancer != nil {
r.removeFromLoadBalancer()
} else {
httpRoutes.Delete(string(r.Alias))
}
if r.HealthMon != nil {
r.HealthMon.Stop()
r.HealthMon = nil
}
r.handler = nil
return
}
func (r *HTTPRoute) Started() bool {
return r.handler != nil
// Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason string) {
r.task.Finish(reason)
}
func (r *HTTPRoute) addToLoadBalancer() {
@ -197,10 +172,14 @@ func (r *HTTPRoute) addToLoadBalancer() {
}
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
lbTask.OnComplete("remove lb from routes", func() {
httpRoutes.Delete(r.LoadBalance.Link)
})
lb.Start(lbTask)
linked = &HTTPRoute{
ReverseProxyEntry: &P.ReverseProxyEntry{
Raw: &types.RawEntry{
ReverseProxyEntry: &entry.ReverseProxyEntry{
Raw: &entry.RawEntry{
Homepage: r.Raw.Homepage,
},
Alias: PT.Alias(lb.Link),
@ -214,16 +193,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server)
}
func (r *HTTPRoute) removeFromLoadBalancer() {
r.loadBalancer.RemoveServer(r.server)
if r.loadBalancer.IsEmpty() {
httpRoutes.Delete(r.LoadBalance.Link)
logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link)
}
r.server = nil
r.loadBalancer = nil
r.task.OnComplete("remove server from lb", func() {
lb.RemoveServer(r.server)
})
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {

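The route rewrite replaces explicit Start/Stop pairs with the task pattern: Start receives a subtask, registers OnComplete cleanup callbacks, and Finish runs them exactly once. A self-contained sketch of that lifecycle (this is an illustration, not the internal/task package itself):

package main

import (
    "fmt"
    "sync"
)

type Task struct {
    mu       sync.Mutex
    done     bool
    cleanups []func()
}

// OnComplete registers cleanup to run when the task finishes.
func (t *Task) OnComplete(name string, fn func()) {
    t.mu.Lock()
    defer t.mu.Unlock()
    t.cleanups = append(t.cleanups, fn)
}

// Finish is idempotent: cleanup callbacks fire at most once.
func (t *Task) Finish(reason string) {
    t.mu.Lock()
    defer t.mu.Unlock()
    if t.done {
        return
    }
    t.done = true
    fmt.Println("finished:", reason)
    for _, fn := range t.cleanups {
        fn()
    }
}

type route struct{ task *Task }

func (r *route) Start(t *Task) {
    r.task = t
    r.task.OnComplete("deregister", func() { fmt.Println("route removed from table") })
}

func main() {
    r := &route{}
    r.Start(&Task{})
    r.task.Finish("provider reload") // cleanup fires exactly once
    r.task.Finish("duplicate call")  // no-op
}
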
View file

@ -10,10 +10,9 @@ import (
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type DockerProvider struct {
@ -43,7 +42,7 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
routes = R.NewRoutes()
entries := types.NewProxyEntries()
entries := entry.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
if err != nil {
@ -66,12 +65,12 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
// there may be some valid entries in `en`
dups := entries.MergeFrom(newEntries)
// add the duplicate proxy entries to the error
dups.RangeAll(func(k string, v *types.RawEntry) {
dups.RangeAll(func(k string, v *entry.RawEntry) {
errors.Addf("duplicate alias %s", k)
})
}
entries.RangeAll(func(_ string, e *types.RawEntry) {
entries.RangeAll(func(_ string, e *entry.RawEntry) {
e.Container.DockerHost = p.dockerHost
})
@ -88,85 +87,10 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
strings.HasSuffix(container.ContainerName, "-old")
}
func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
matches := R.NewRoutes()
oldRoutes.RangeAllParallel(func(k string, v *R.Route) {
if v.Entry.Container.ContainerID == event.ActorID ||
v.Entry.Container.ContainerName == event.ActorName {
matches.Store(k, v)
}
})
//FIXME: docker event die stuck
var newRoutes R.Routes
var err E.NestedError
switch {
// id & container name changed
case matches.Size() == 0:
matches = oldRoutes
newRoutes, err = p.LoadRoutesImpl()
b.Add(err)
case event.Action == events.ActionContainerDestroy:
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
oldRoutes.Delete(v.Entry.Alias)
b.Add(v.Stop())
res.nRemoved++
})
return
default:
cont, err := D.Inspect(p.dockerHost, event.ActorID)
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
if p.shouldIgnore(cont) {
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
res.nRemoved++
})
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
newRoutes, err = R.FromEntries(entries)
b.Add(err)
}
matches.RangeAll(func(k string, v *R.Route) {
if !newRoutes.Has(k) && !oldRoutes.Has(k) {
b.Add(v.Stop())
matches.Delete(k)
res.nRemoved++
}
})
newRoutes.RangeAll(func(alias string, newRoute *R.Route) {
oldRoute, exists := oldRoutes.Load(alias)
if exists {
b.Add(oldRoute.Stop())
res.nReloaded++
} else {
res.nAdded++
}
b.Add(newRoute.Start())
oldRoutes.Store(alias, newRoute)
})
return
}
// Returns a list of proxy entries for a container.
// Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries types.RawEntries, _ E.NestedError) {
entries = types.NewProxyEntries()
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
entries = entry.NewProxyEntries()
if p.shouldIgnore(container) {
return
@ -174,7 +98,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
// init entries map for all aliases
for _, a := range container.Aliases {
entries.Store(a, &types.RawEntry{
entries.Store(a, &entry.RawEntry{
Alias: a,
Container: container,
})
@ -186,14 +110,14 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
}
// remove all entries that failed to fill in missing fields
entries.RangeAll(func(_ string, re *types.RawEntry) {
entries.RangeAll(func(_ string, re *entry.RawEntry) {
re.FillMissingFields()
})
return entries, errors.Build().Subject(container.ContainerName)
}
func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEntries, key, val string) (res E.NestedError) {
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)
@ -220,7 +144,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt
}
if lbl.Target == D.WildcardAlias {
// apply label for all aliases
entries.RangeAll(func(a string, e *types.RawEntry) {
entries.RangeAll(func(a string, e *entry.RawEntry) {
if err = D.ApplyLabel(e, lbl); err != nil {
b.Add(err)
}

View file

@ -10,7 +10,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
"github.com/yusing/go-proxy/internal/proxy/entry"
T "github.com/yusing/go-proxy/internal/proxy/fields"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -46,7 +46,7 @@ func TestApplyLabelWildcard(t *testing.T) {
Names: dummyNames,
Labels: map[string]string{
D.LabelAliases: "a,b",
D.LabelIdleTimeout: common.IdleTimeoutDefault,
D.LabelIdleTimeout: "",
D.LabelStopMethod: common.StopMethodDefault,
D.LabelStopSignal: "SIGTERM",
D.LabelStopTimeout: common.StopTimeoutDefault,
@ -62,7 +62,7 @@ func TestApplyLabelWildcard(t *testing.T) {
"proxy.a.middlewares.middleware2.prop3": "value3",
"proxy.a.middlewares.middleware2.prop4": "value4",
},
}, ""))
}, client.DefaultDockerHost))
ExpectNoError(t, err.Error())
a, ok := entries.Load("a")
@ -88,8 +88,8 @@ func TestApplyLabelWildcard(t *testing.T) {
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.Container.IdleTimeout, "")
ExpectEqual(t, b.Container.IdleTimeout, "")
ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault)
@ -107,6 +107,7 @@ func TestApplyLabelWildcard(t *testing.T) {
func TestApplyLabelWithAlias(t *testing.T) {
entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
State: "running",
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.a.no_tls_verify": "true",
@ -114,7 +115,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
"proxy.b.port": "1234",
"proxy.c.scheme": "https",
},
}, ""))
}, client.DefaultDockerHost))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
@ -134,6 +135,7 @@ func TestApplyLabelWithAlias(t *testing.T) {
func TestApplyLabelWithRef(t *testing.T) {
entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
State: "running",
Labels: map[string]string{
D.LabelAliases: "a,b,c",
"proxy.#1.host": "localhost",
@ -142,7 +144,7 @@ func TestApplyLabelWithRef(t *testing.T) {
"proxy.#3.port": "1111",
"proxy.#3.scheme": "https",
},
}, "")))
}, client.DefaultDockerHost)))
a, ok := entries.Load("a")
ExpectTrue(t, ok)
b, ok := entries.Load("b")
@ -161,6 +163,7 @@ func TestApplyLabelWithRef(t *testing.T) {
func TestApplyLabelWithRefIndexError(t *testing.T) {
c := D.FromDocker(&types.Container{
Names: dummyNames,
State: "running",
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#1.host": "localhost",
@ -173,6 +176,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
_, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{
Names: dummyNames,
State: "running",
Labels: map[string]string{
D.LabelAliases: "a,b",
"proxy.#0.host": "localhost",
@ -183,7 +187,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) {
}
func TestPublicIPLocalhost(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost)
c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1")
@ -191,7 +195,7 @@ func TestPublicIPLocalhost(t *testing.T) {
}
func TestPublicIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375")
c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
@ -218,6 +222,7 @@ func TestPrivateIPLocalhost(t *testing.T) {
func TestPrivateIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{
Names: dummyNames,
State: "running",
NetworkSettings: &types.SummaryNetworkSettings{
Networks: map[string]*network.EndpointSettings{
"network": {
@ -239,6 +244,7 @@ func TestStreamDefaultValues(t *testing.T) {
privIP := "172.17.0.123"
cont := &types.Container{
Names: []string{"a"},
State: "running",
NetworkSettings: &types.SummaryNetworkSettings{
Networks: map[string]*network.EndpointSettings{
"network": {
@ -256,9 +262,8 @@ func TestStreamDefaultValues(t *testing.T) {
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
entry := Must(P.ValidateEntry(raw))
a := ExpectType[*P.StreamEntry](t, entry)
en := Must(entry.ValidateEntry(raw))
a := ExpectType[*entry.StreamEntry](t, en)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Host, T.Host(privIP))
@ -270,9 +275,8 @@ func TestStreamDefaultValues(t *testing.T) {
c := D.FromDocker(cont, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
entry := Must(P.ValidateEntry(raw))
a := ExpectType[*P.StreamEntry](t, entry)
en := Must(entry.ValidateEntry(raw))
a := ExpectType[*entry.StreamEntry](t, en)
ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp"))
ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp"))
ExpectEqual(t, a.Host, "1.2.3.4")

View file

@ -0,0 +1,109 @@
package provider
import (
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher"
)
type EventHandler struct {
provider *Provider
added []string
removed []string
paused []string
updated []string
errs E.Builder
}
func (provider *Provider) newEventHandler() *EventHandler {
return &EventHandler{
provider: provider,
errs: E.NewBuilder("event errors"),
}
}
func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
oldRoutes := handler.provider.routes
newRoutes, err := handler.provider.LoadRoutesImpl()
if err != nil {
handler.errs.Add(err.Subject("load routes"))
return
}
oldRoutes.RangeAll(func(k string, v *route.Route) {
if !newRoutes.Has(k) {
handler.Remove(v)
}
})
newRoutes.RangeAll(func(k string, newr *route.Route) {
if oldRoutes.Has(k) {
for _, ev := range events {
if handler.match(ev, newr) {
old, ok := oldRoutes.Load(k)
if !ok { // should not happen
panic("race condition")
}
handler.Update(parent, old, newr)
return
}
}
} else {
handler.Add(parent, newr)
}
})
}
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.t {
case ProviderTypeDocker:
return route.Entry.Container.ContainerID == event.ActorID ||
route.Entry.Container.ContainerName == event.ActorName
case ProviderTypeFile:
return true
}
// should never happen
return false
}
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
err := handler.provider.startRoute(parent, route)
if err != nil {
handler.errs.Add(err)
} else {
handler.added = append(handler.added, route.Entry.Alias)
}
}
func (handler *EventHandler) Remove(route *route.Route) {
route.Finish("route removal")
handler.removed = append(handler.removed, route.Entry.Alias)
}
func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) {
oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute)
if err != nil {
handler.errs.Add(err)
} else {
handler.updated = append(handler.updated, newRoute.Entry.Alias)
}
}
func (handler *EventHandler) Log() {
results := E.NewBuilder("event occurred")
for _, alias := range handler.added {
results.Addf("added %s", alias)
}
for _, alias := range handler.removed {
results.Addf("removed %s", alias)
}
for _, alias := range handler.updated {
results.Addf("updated %s", alias)
}
results.Add(handler.errs.Build())
if result := results.Build(); result != nil {
handler.provider.l.Info(result)
}
}
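
Handle above is a three-way set reconciliation: routes present only in the old set are removed, routes only in the new set are added, and overlapping routes touched by an event are restarted. A map-based sketch of that logic (the touched predicate stands in for the provider-specific match):

package main

import "fmt"

func reconcile(old, next map[string]bool, touched func(string) bool) (added, removed, updated []string) {
    for k := range old {
        if !next[k] {
            removed = append(removed, k) // gone from the new config
        }
    }
    for k := range next {
        switch {
        case !old[k]:
            added = append(added, k) // brand-new route
        case touched(k):
            updated = append(updated, k) // existing route affected by an event
        }
    }
    return
}

func main() {
    old := map[string]bool{"a": true, "b": true}
    cur := map[string]bool{"b": true, "c": true}
    a, r, u := reconcile(old, cur, func(k string) bool { return k == "b" })
    fmt.Println(a, r, u) // [c] [a] [b]
}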

View file

@ -7,8 +7,8 @@ import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/types"
U "github.com/yusing/go-proxy/internal/utils"
W "github.com/yusing/go-proxy/internal/watcher"
)
@ -42,38 +42,13 @@ func (p FileProvider) String() string {
return p.fileName
}
func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
newRoutes, err := p.LoadRoutesImpl()
if err != nil {
b.Add(err)
return
}
res.nRemoved = newRoutes.Size()
routes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
})
routes.Clear()
newRoutes.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Start())
})
res.nAdded = newRoutes.Size()
routes.MergeFrom(newRoutes)
return
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
routes = R.NewRoutes()
b := E.NewBuilder("file %q validation failure", p.fileName)
defer b.To(&res)
entries := types.NewProxyEntries()
entries := entry.NewProxyEntries()
data, err := E.Check(os.ReadFile(p.path))
if err != nil {

View file

@ -1,14 +1,16 @@
package provider
import (
"context"
"fmt"
"path"
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events"
)
type (
@ -19,18 +21,14 @@ type (
t ProviderType
routes R.Routes
watcher W.Watcher
watcherTask common.Task
watcherCancel context.CancelFunc
watcher W.Watcher
l *logrus.Entry
}
ProviderImpl interface {
fmt.Stringer
NewWatcher() W.Watcher
// even if it returns an error, routes must be non-nil
LoadRoutesImpl() (R.Routes, E.NestedError)
OnEvent(event W.Event, routes R.Routes) EventResult
String() string
}
ProviderType string
ProviderStats struct {
@ -38,17 +36,13 @@ type (
NumStreams int `json:"num_streams"`
Type ProviderType `json:"type"`
}
EventResult struct {
nAdded int
nRemoved int
nReloaded int
err E.NestedError
}
)
const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 500 * time.Millisecond
)
func newProvider(name string, t ProviderType) *Provider {
@ -106,32 +100,48 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) StartAllRoutes() (res E.NestedError) {
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
subtask := parent.Subtask("route %s", r.Entry.Alias)
err := r.Start(subtask)
if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err.String()) // just to ensure
return err
} else {
subtask.OnComplete("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
}
return nil
}
// Start implements task.TaskStarter.
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
errors := E.NewBuilder("errors starting routes")
defer errors.To(&res)
// start watcher regardless of whether routes loaded successfully
go p.watchEvents()
// routes and event queue will stop on parent cancel
providerTask := configSubtask
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Start().Subject(r))
errors.Add(p.startRoute(providerTask, r))
})
return
}
func (p *Provider) StopAllRoutes() (res E.NestedError) {
if p.watcherCancel != nil {
p.watcherCancel()
p.watcherCancel = nil
}
errors := E.NewBuilder("errors stopping routes")
defer errors.To(&res)
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Stop().Subject(r))
})
p.routes.Clear()
eventQueue := events.NewEventQueue(
providerTask,
providerEventFlushInterval,
func(flushTask task.Task, events []events.Event) {
handler := p.newEventHandler()
// routes' lifetime should follow the provider's lifetime
handler.Handle(providerTask, events)
handler.Log()
flushTask.Finish("events flushed")
},
func(err E.NestedError) {
p.l.Error(err)
},
)
eventQueue.Start(p.watcher.Events(providerTask.Context()))
return
}
@ -147,7 +157,6 @@ func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError
p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 {
p.l.Infof("loaded %d routes", p.routes.Size())
return err
}
if err == nil {
@ -156,13 +165,14 @@ func (p *Provider) LoadRoutes() E.NestedError {
return E.FailWith("loading routes", err)
}
func (p *Provider) NumRoutes() int {
return p.routes.Size()
}
func (p *Provider) Statistics() ProviderStats {
numRPs := 0
numStreams := 0
p.routes.RangeAll(func(_ string, r *R.Route) {
if !r.Started() {
return
}
switch r.Type {
case R.RouteTypeReverseProxy:
numRPs++
@ -176,34 +186,3 @@ func (p *Provider) Statistics() ProviderStats {
Type: p.t,
}
}
func (p *Provider) watchEvents() {
p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name)
defer p.watcherTask.Finished()
events, errs := p.watcher.Events(p.watcherTask.Context())
l := p.l.WithField("module", "watcher")
for {
select {
case <-p.watcherTask.Context().Done():
return
case event := <-events:
task := p.watcherTask.Subtask("%s event %s", event.Type, event)
l.Infof("%s event %q", event.Type, event)
res := p.OnEvent(event, p.routes)
task.Finished()
if res.nAdded+res.nRemoved+res.nReloaded > 0 {
l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded)
}
if res.err != nil {
l.Error(res.err)
}
case err := <-errs:
if err == nil || err.Is(context.Canceled) {
continue
}
l.Errorf("watcher error: %s", err)
}
}
}

View file

@ -4,8 +4,8 @@ import (
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
"github.com/yusing/go-proxy/internal/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@ -16,16 +16,16 @@ type (
_ U.NoCopy
impl
Type RouteType
Entry *types.RawEntry
Entry *entry.RawEntry
}
Routes = F.Map[string, *Route]
impl interface {
Start() E.NestedError
Stop() E.NestedError
Started() bool
entry.Entry
task.TaskStarter
task.TaskFinisher
String() string
URL() url.URL
TargetURL() url.URL
}
)
@ -44,8 +44,8 @@ func (rt *Route) Container() *docker.Container {
return rt.Entry.Container
}
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
entry, err := P.ValidateEntry(en)
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
en, err := entry.ValidateEntry(raw)
if err != nil {
return nil, err
}
@ -53,11 +53,11 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
var t RouteType
var rt impl
switch e := entry.(type) {
case *P.StreamEntry:
switch e := en.(type) {
case *entry.StreamEntry:
t = RouteTypeStream
rt, err = NewStreamRoute(e)
case *P.ReverseProxyEntry:
case *entry.ReverseProxyEntry:
t = RouteTypeReverseProxy
rt, err = NewHTTPRoute(e)
default:
@ -69,19 +69,21 @@ func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
return &Route{
impl: rt,
Type: t,
Entry: en,
Entry: raw,
}, nil
}
func FromEntries(entries types.RawEntries) (Routes, E.NestedError) {
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()
entries.RangeAll(func(alias string, entry *types.RawEntry) {
entry.Alias = alias
r, err := NewRoute(entry)
entries.RangeAllParallel(func(alias string, en *entry.RawEntry) {
en.Alias = alias
r, err := NewRoute(en)
if err != nil {
b.Add(err.Subject(alias))
} else if entry.ShouldNotServe(r) {
return
} else {
routes.Store(alias, r)
}

View file

@ -4,169 +4,141 @@ import (
"context"
"errors"
"fmt"
"net"
stdNet "net"
"sync"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type StreamRoute struct {
*P.StreamEntry
StreamImpl `json:"-"`
*entry.StreamEntry
net.Stream `json:"-"`
HealthMon health.HealthMonitor `json:"health"`
url url.URL
task common.Task
cancel context.CancelFunc
done chan struct{}
task task.Task
l logrus.FieldLogger
mu sync.Mutex
}
type StreamImpl interface {
Setup() error
Accept() (any, error)
Handle(conn any) error
CloseListeners()
String() string
}
var streamRoutes = F.NewMapOf[string, *StreamRoute]()
var (
streamRoutes = F.NewMapOf[string, *StreamRoute]()
streamRoutesMu sync.Mutex
)
func GetStreamProxies() F.Map[string, *StreamRoute] {
return streamRoutes
}
func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
// TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
}
url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort))
if err != nil {
// !! should not happen: the URL is built from already-validated entry fields
panic(err)
}
base := &StreamRoute{
return &StreamRoute{
StreamEntry: entry,
url: url,
}
if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base)
} else {
base.StreamImpl = NewUDPRoute(base)
}
base.l = logrus.WithField("route", base.StreamImpl)
return base, nil
task: task.DummyTask(),
}, nil
}
func (r *StreamRoute) Finish(reason string) {
r.task.Finish(reason)
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("stream %s", r.Alias)
}
func (r *StreamRoute) URL() url.URL {
return r.url
}
func (r *StreamRoute) Start() E.NestedError {
r.mu.Lock()
defer r.mu.Unlock()
if r.Port.ProxyPort == PT.NoPort || r.task != nil {
// Start implements task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
}
r.task, r.cancel = common.NewTaskWithCancel(r.String())
streamRoutesMu.Lock()
defer streamRoutesMu.Unlock()
if r.HealthCheck.Disabled && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
logrus.Warnf("%s.healthCheck.disabled cannot be true when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = false
}
}
if r.Scheme.ListeningScheme.IsTCP() {
r.Stream = NewTCPRoute(r)
} else {
r.Stream = NewUDPRoute(r)
}
r.l = logrus.WithField("route", r.Stream.String())
switch {
case entry.UseIdleWatcher(r):
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil {
return err
}
r.Stream = waker
r.HealthMon = waker
case entry.UseHealthCheck(r):
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
}
r.task = providerSubtask
r.task.OnComplete("stop stream", r.CloseListeners)
if err := r.Setup(); err != nil {
return E.FailWith("setup", err)
}
r.done = make(chan struct{})
r.l.Infof("listening on port %d", r.Port.ListeningPort)
go r.acceptConnections()
if !r.Healthcheck.Disabled {
r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck)
r.HealthMon.Start()
if r.HealthMon != nil {
r.HealthMon.Start(r.task.Subtask("health monitor"))
}
streamRoutes.Store(string(r.Alias), r)
return nil
}
func (r *StreamRoute) Stop() E.NestedError {
r.mu.Lock()
defer r.mu.Unlock()
if r.task == nil {
return nil
}
streamRoutes.Delete(string(r.Alias))
if r.HealthMon != nil {
r.HealthMon.Stop()
r.HealthMon = nil
}
r.cancel()
r.CloseListeners()
<-r.done
return nil
}
func (r *StreamRoute) Started() bool {
return r.task != nil
}
func (r *StreamRoute) acceptConnections() {
var connWg sync.WaitGroup
task := r.task.Subtask("%s accept connections", r.String())
defer func() {
connWg.Wait()
task.Finished()
r.task.Finished()
r.task, r.cancel = nil, nil
close(r.done)
r.done = nil
}()
for {
select {
case <-task.Context().Done():
case <-r.task.Context().Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-task.Context().Done():
case <-r.task.Context().Done():
return
default:
var nErr *net.OpError
var nErr *stdNet.OpError
ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) {
r.l.Error(err)
r.l.Error("accept connection error: ", err)
r.task.Finish(err.Error())
return
}
continue
}
}
connWg.Add(1)
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
connTask.Finish(err.Error())
} else {
connTask.Finish("connection closed")
}
connWg.Done()
conn.Close()
}()
}
}

View file

@ -6,6 +6,7 @@ import (
"net"
"time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@ -21,7 +22,7 @@ type (
}
)
func NewTCPRoute(base *StreamRoute) StreamImpl {
func NewTCPRoute(base *StreamRoute) *TCPRoute {
return &TCPRoute{StreamRoute: base}
}
@ -36,19 +37,16 @@ func (route *TCPRoute) Setup() error {
return nil
}
func (route *TCPRoute) Accept() (any, error) {
func (route *TCPRoute) Accept() (types.StreamConn, error) {
route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept()
}
func (route *TCPRoute) Handle(c any) error {
func (route *TCPRoute) Handle(c types.StreamConn) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
go func() {
<-route.task.Context().Done()
clientConn.Close()
}()
route.task.OnComplete("close conn", func() { clientConn.Close() })
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
@ -70,5 +68,4 @@ func (route *TCPRoute) CloseListeners() {
return
}
route.listener.Close()
route.listener = nil
}

View file

@ -1,11 +1,13 @@
package route
import (
"errors"
"fmt"
"io"
"net"
"time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@ -33,7 +35,7 @@ var NewUDPConnMap = F.NewMap[UDPConnMap]
const udpBufferSize = 8192
func NewUDPRoute(base *StreamRoute) StreamImpl {
func NewUDPRoute(base *StreamRoute) *UDPRoute {
return &UDPRoute{
StreamRoute: base,
connMap: NewUDPConnMap(),
@ -64,7 +66,7 @@ func (route *UDPRoute) Setup() error {
return nil
}
func (route *UDPRoute) Accept() (any, error) {
func (route *UDPRoute) Accept() (types.StreamConn, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
@ -104,7 +106,7 @@ func (route *UDPRoute) Accept() (any, error) {
return conn, err
}
func (route *UDPRoute) Handle(c any) error {
func (route *UDPRoute) Handle(c types.StreamConn) error {
conn := c.(*UDPConn)
err := conn.Start()
route.connMap.Delete(conn.key)
@ -114,19 +116,25 @@ func (route *UDPRoute) Handle(c any) error {
func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
route.listeningConn = nil
}
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.src.Close(); err != nil {
route.l.Errorf("error closing src conn: %s", err)
}
if err := conn.dst.Close(); err != nil {
route.l.Error("error closing dst conn: %s", err)
if err := conn.Close(); err != nil {
route.l.Errorf("error closing conn: %s", err)
}
})
route.connMap.Clear()
}
// Close implements types.StreamConn
func (conn *UDPConn) Close() error {
return errors.Join(conn.src.Close(), conn.dst.Close())
}
// RemoteAddr implements types.StreamConn
func (conn *UDPConn) RemoteAddr() net.Addr {
return conn.src.RemoteAddr()
}
type sourceRWCloser struct {
server *net.UDPConn
*net.UDPConn

View file

@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/tls"
"errors"
"log"
@ -9,8 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"golang.org/x/net/context"
"github.com/yusing/go-proxy/internal/task"
)
type Server struct {
@ -21,7 +21,8 @@ type Server struct {
httpStarted bool
httpsStarted bool
startTime time.Time
task common.Task
task task.Task
}
type Options struct {
@ -84,7 +85,7 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
task: common.GlobalTask(opt.Name + " server"),
task: task.GlobalTask(opt.Name + " server"),
}
}
@ -115,11 +116,7 @@ func (s *Server) Start() {
}()
}
go func() {
<-s.task.Context().Done()
s.stop()
s.task.Finished()
}()
s.task.OnComplete("stop server", s.stop)
}
func (s *Server) stop() {
@ -127,16 +124,13 @@ func (s *Server) stop() {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(ctx))
s.handleErr("http", s.http.Shutdown(s.task.Context()))
s.httpStarted = false
}
if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(ctx))
s.handleErr("https", s.https.Shutdown(s.task.Context()))
s.httpsStarted = false
}
}
@ -147,7 +141,7 @@ func (s *Server) Uptime() time.Duration {
func (s *Server) handleErr(scheme string, err error) {
switch {
case err == nil, errors.Is(err, http.ErrServerClosed):
case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled):
return
default:
logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err)

View file

@ -0,0 +1,43 @@
package task
import "context"
type dummyTask struct{}
func DummyTask() Task {
return dummyTask{}
}
// Context implements Task.
func (d dummyTask) Context() context.Context {
panic("call of dummyTask.Context")
}
// Finish implements Task.
func (d dummyTask) Finish(reason string) {}
// FinishCause implements Task.
func (d dummyTask) FinishCause() error {
return nil
}
// Name implements Task.
func (d dummyTask) Name() string {
return "Dummy Task"
}
// OnComplete implements Task.
func (d dummyTask) OnComplete(about string, fn func()) {
panic("call of dummyTask.OnComplete")
}
// Parent implements Task.
func (d dummyTask) Parent() Task {
panic("call of dummyTask.Parent")
}
// Subtask implements Task.
func (d dummyTask) Subtask(usageFmt string, args ...any) Task {
panic("call of dummyTask.Subtask")
}
// Wait implements Task.
func (d dummyTask) Wait() {}
// WaitSubTasks implements Task.
func (d dummyTask) WaitSubTasks() {}

internal/task/task.go (new file, 310 lines added)
View file

@ -0,0 +1,310 @@
package task
import (
"context"
"errors"
"fmt"
"runtime"
"strings"
"sync"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
var globalTask = createGlobalTask()
func createGlobalTask() (t *task) {
t = new(task)
t.name = "root"
t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.subtasks = xsync.NewMapOf[*task, struct{}]()
return
}
type (
// Task controls an object's lifetime.
//
// A Task must be initialized; use DummyTask if the task has not started yet.
//
// Objects that use a task should implement the TaskStarter and TaskFinisher interfaces.
//
// When passing a Task object to another function,
// it should be a subtask of the current task,
// named "`currentTaskName` subtask".
//
// Use Task.Finish to stop the task and all of its subtasks.
Task interface {
TaskFinisher
// Name returns the name of the task.
Name() string
// Context returns the context associated with the task. This context is
// canceled when Finish is called.
Context() context.Context
// FinishCause returns the reason / error that caused the task to be finished.
FinishCause() error
// Parent returns the parent task of the current task.
Parent() Task
// Subtask returns a new subtask with the given name, derived from the parent's context.
//
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
//
// This should not be called after Finish, Wait, or WaitSubTasks is called.
Subtask(usageFmt string, args ...any) Task
// OnComplete calls fn when the task and all subtasks are finished.
//
// It cannot be called after Finish or Wait is called.
OnComplete(about string, fn func())
// Wait waits for all subtasks, itself and all OnComplete to finish.
//
// It must be called only after Finish is called.
Wait()
// WaitSubTasks waits for all subtasks of the task to finish.
//
// No more subtasks can be added after this call.
//
// It can be called before Finish is called.
WaitSubTasks()
}
TaskStarter interface {
// Start starts the object that implements TaskStarter,
// and returns an error if it fails to start.
//
// The task passed must be a subtask of the caller task.
//
// callerSubtask.Finish must be called when start fails or the object is finished.
Start(callerSubtask Task) E.NestedError
}
TaskFinisher interface {
// Finish marks the task as finished by canceling its context,
// then waits for all subtasks and OnComplete callbacks of the task to finish.
//
// Note that it will also cancel all subtasks.
Finish(reason string)
}
task struct {
ctx context.Context
cancel context.CancelCauseFunc
parent *task
subtasks *xsync.MapOf[*task, struct{}]
name, line string
subTasksWg, onCompleteWg sync.WaitGroup
}
)
var (
ErrProgramExiting = errors.New("program exiting")
ErrTaskCancelled = errors.New("task cancelled")
)
// GlobalTask returns a new Task with the given name, derived from the global context.
func GlobalTask(format string, args ...any) Task {
return globalTask.Subtask(format, args...)
}
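// Illustrative sketch (not part of this commit): a typical Task lifecycle.
// The task names and the cleanup body are placeholder assumptions; only the
// API defined above is used.
//
//	srv := GlobalTask("example server")
//	worker := srv.Subtask("worker")
//	worker.OnComplete("close resources", func() {
//		// release listeners, flush buffers, etc.
//	})
//	go func() {
//		<-worker.Context().Done() // stop working once the task is canceled
//	}()
//	srv.Finish("shutting down") // cancels worker, runs OnComplete, then waits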
// DebugTaskMap returns a map[string]any representation of the global task tree.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func DebugTaskMap() map[string]any {
return globalTask.serialize()
}
// CancelGlobalContext cancels the global task context, which will cause all tasks
// created to be canceled. This should be called before exiting the program
// to ensure that all tasks are properly cleaned up.
func CancelGlobalContext() {
globalTask.cancel(ErrProgramExiting)
}
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) {
done := make(chan struct{})
after := time.After(timeout)
go func() {
globalTask.Wait()
close(done)
}()
select {
case <-done:
case <-after:
logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
}
}
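// Illustrative sketch (not part of this commit): a graceful-shutdown path
// combining the two helpers above; the 3-second timeout is an arbitrary
// assumption.
//
//	CancelGlobalContext()              // cancel every task in the tree
//	GlobalContextWait(3 * time.Second) // wait, or log the stuck task tree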
func (t *task) Name() string {
return t.name
}
func (t *task) Context() context.Context {
return t.ctx
}
func (t *task) FinishCause() error {
return context.Cause(t.ctx)
}
func (t *task) Parent() Task {
return t.parent
}
func (t *task) OnComplete(about string, fn func()) {
t.onCompleteWg.Add(1)
var file string
var line int
if common.IsTrace {
_, file, line, _ = runtime.Caller(1)
}
go func() {
defer func() {
if err := recover(); err != nil {
logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err)
}
}()
defer t.onCompleteWg.Done()
t.subTasksWg.Wait()
<-t.ctx.Done()
fn()
logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about)
t.cancel(nil) // ensure resources are released
}()
}
func (t *task) Finish(reason string) {
t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason))
t.Wait()
}
func (t *task) Subtask(format string, args ...any) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
ctx, cancel := context.WithCancelCause(t.ctx)
return t.newSubTask(ctx, cancel, format)
}
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
parent := t
subtask := &task{
ctx: ctx,
cancel: cancel,
name: name,
parent: parent,
subtasks: xsync.NewMapOf[*task, struct{}](),
}
parent.subTasksWg.Add(1)
parent.subtasks.Store(subtask, struct{}{})
if common.IsTrace {
_, file, line, ok := runtime.Caller(3)
if ok {
subtask.line = fmt.Sprintf("%s:%d", file, line)
}
logrus.Tracef("line %s\ntask %q started", subtask.line, name)
go func() {
subtask.Wait()
logrus.Tracef("task %q finished", subtask.Name())
}()
}
go func() {
subtask.Wait()
parent.subtasks.Delete(subtask)
parent.subTasksWg.Done()
}()
return subtask
}
func (t *task) Wait() {
t.subTasksWg.Wait()
if t != globalTask {
<-t.ctx.Done()
}
t.onCompleteWg.Wait()
}
func (t *task) WaitSubTasks() {
t.subTasksWg.Wait()
}
// tree returns a string representation of the task tree, with the given
// prefix prepended to each line. The prefix is used to indent the tree,
// and should be a string of spaces or a similar separator.
//
// The resulting string is suitable for printing to the console, and can be
// used to debug the task tree.
//
// The tree is traversed in a depth-first manner, with each task's name and
// line number (if available) printed on a separate line. The line number is
// only printed if the task was created with a non-empty line argument.
//
// The returned string is not guaranteed to be stable, and may change between
// runs of the program. It is intended for debugging purposes only.
func (t *task) tree(prefix ...string) string {
var sb strings.Builder
var pre string
if len(prefix) > 0 {
pre = prefix[0]
sb.WriteString(pre + "- ")
}
if t.line != "" {
sb.WriteString("line " + t.line + "\n")
}
if len(pre) > 0 {
sb.WriteString(pre + "- ")
}
sb.WriteString(t.Name() + "\n")
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
sb.WriteString(subtask.tree(pre + " "))
return true
})
return sb.String()
}
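// Illustrative sketch (not part of this commit): tree() output for a root
// task with one subtask, when no line info was recorded (line info is only
// captured in trace mode):
//
//	root
//	 - subtask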
// serialize returns a map[string]any representation of the task tree.
//
// The map contains the following keys:
// - name: the name of the task
// - line: the line number of the task, if available
// - subtasks: a slice of maps, each representing a subtask
//
// The subtask maps contain the same keys, recursively.
//
// The returned map is suitable for encoding to JSON, and can be used
// to debug the task tree.
//
// The returned map is not guaranteed to be stable, and may change
// between runs of the program. It is intended for debugging purposes
// only.
func (t *task) serialize() map[string]any {
m := make(map[string]any)
m["name"] = t.name
if t.line != "" {
m["line"] = t.line
}
if t.subtasks.Size() > 0 {
m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
return true
})
}
return m
}
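// Illustrative sketch (not part of this commit): dumping the global task
// tree as JSON for debugging; the encoding/json and fmt imports are assumed.
//
//	b, _ := json.MarshalIndent(DebugTaskMap(), "", "  ")
//	fmt.Println(string(b)) // e.g. {"name":"root","subtasks":[{"name":"..."}]}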

internal/task/task_test.go (new file, 147 lines added)
View file

@ -0,0 +1,147 @@
package task_test
import (
"context"
"sync/atomic"
"testing"
"time"
. "github.com/yusing/go-proxy/internal/task"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestTaskCreation(t *testing.T) {
defer CancelGlobalContext()
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
ExpectEqual(t, "root-task", rootTask.Name())
ExpectEqual(t, "subtask", subTask.Name())
}
func TestTaskCancellation(t *testing.T) {
defer CancelGlobalContext()
subTaskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
go func() {
subTask.Wait()
close(subTaskDone)
}()
go rootTask.Finish("done")
select {
case <-subTaskDone:
err := subTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(subTask.Context())
ExpectError(t, ErrTaskCancelled, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected")
}
}
func TestGlobalContextCancellation(t *testing.T) {
taskDone := make(chan struct{})
rootTask := GlobalTask("root-task")
go func() {
rootTask.Wait()
close(taskDone)
}()
CancelGlobalContext()
select {
case <-taskDone:
err := rootTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(rootTask.Context())
ExpectError(t, ErrProgramExiting, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected")
}
}
func TestOnComplete(t *testing.T) {
defer CancelGlobalContext()
task := GlobalTask("test")
var value atomic.Int32
task.OnComplete("set value", func() {
value.Store(1234)
})
task.Finish("done")
ExpectEqual(t, value.Load(), 1234)
}
func TestGlobalContextWait(t *testing.T) {
defer CancelGlobalContext()
rootTask := GlobalTask("root-task")
finished1, finished2 := false, false
subTask1 := rootTask.Subtask("subtask1")
subTask2 := rootTask.Subtask("subtask2")
subTask1.OnComplete("set finished", func() {
finished1 = true
})
subTask2.OnComplete("set finished", func() {
finished2 = true
})
go func() {
time.Sleep(500 * time.Millisecond)
subTask1.Finish("done")
}()
go func() {
time.Sleep(500 * time.Millisecond)
subTask2.Finish("done")
}()
go func() {
subTask1.Wait()
subTask2.Wait()
rootTask.Finish("done")
}()
GlobalContextWait(1 * time.Second)
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context()))
}
func TestTimeoutOnGlobalContextWait(t *testing.T) {
defer CancelGlobalContext()
rootTask := GlobalTask("root-task")
subTask := rootTask.Subtask("subtask")
done := make(chan struct{})
go func() {
GlobalContextWait(500 * time.Millisecond)
close(done)
}()
select {
case <-done:
t.Fatal("GlobalContextWait should have timed out")
case <-time.After(200 * time.Millisecond):
}
// Ensure clean exit
subTask.Finish("exit")
}
func TestGlobalContextCancel(t *testing.T) {
}

View file

@ -1,18 +0,0 @@
package types
type Config struct {
Providers ProxyProviders `json:"providers" yaml:",flow"`
AutoCert AutoCertConfig `json:"autocert" yaml:",flow"`
ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"`
MatchDomains []string `json:"match_domains" yaml:"match_domains"`
TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"`
RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"`
}
func DefaultConfig() *Config {
return &Config{
Providers: ProxyProviders{},
TimeoutShutdown: 3,
RedirectToHTTPS: false,
}
}

View file

@ -1,6 +0,0 @@
package types
type ProxyProviders struct {
Files []string `json:"include" yaml:"include"` // docker, file
Docker map[string]string `json:"docker" yaml:"docker"`
}

View file

@ -23,7 +23,7 @@ func IgnoreError[Result any](r Result, _ error) Result {
func ExpectNoError(t *testing.T, err error) {
t.Helper()
if err != nil && !reflect.ValueOf(err).IsNil() {
t.Errorf("expected err=nil, got %s", err.Error())
t.Errorf("expected err=nil, got %s", err)
t.FailNow()
}
}
@ -31,7 +31,7 @@ func ExpectNoError(t *testing.T, err error) {
func ExpectError(t *testing.T, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("expected err %s, got %s", expected.Error(), err.Error())
t.Errorf("expected err %s, got %s", expected, err)
t.FailNow()
}
}
@ -39,7 +39,7 @@ func ExpectError(t *testing.T, expected error, err error) {
func ExpectError2(t *testing.T, input any, expected error, err error) {
t.Helper()
if !errors.Is(err, expected) {
t.Errorf("%v: expected err %s, got %s", input, expected.Error(), err.Error())
t.Errorf("%v: expected err %s, got %s", input, expected, err)
t.FailNow()
}
}

View file

@ -15,8 +15,9 @@ import (
type (
DockerWatcher struct {
host string
client D.Client
clientOwned bool
logrus.FieldLogger
}
DockerListOptions = docker_events.ListOptions
@ -44,10 +45,11 @@ func DockerrFilterContainer(nameOrID string) filters.KeyValuePair {
func NewDockerWatcher(host string) DockerWatcher {
return DockerWatcher{
host: host,
clientOwned: true,
FieldLogger: (logrus.
WithField("module", "docker_watcher").
WithField("host", host)),
}
}
@ -72,7 +74,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
defer close(errCh)
defer func() {
if w.client.Connected() {
if w.clientOwned && w.client.Connected() {
w.client.Close()
}
}()

View file

@ -0,0 +1,91 @@
package events
import (
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
)
type (
EventQueue struct {
task task.Task
queue []Event
ticker *time.Ticker
onFlush OnFlushFunc
onError OnErrorFunc
}
OnFlushFunc = func(flushTask task.Task, events []Event)
OnErrorFunc = func(err E.NestedError)
)
const eventQueueCapacity = 10
// NewEventQueue returns a new EventQueue with the given
// parent task, flushInterval, onFlush and onError.
//
// The returned EventQueue starts a goroutine that flushes queued events
// each time flushInterval elapses.
//
// The onFlush function is called when flushInterval elapses and the queue is not empty.
//
// The onError function is called when an error is received from errCh,
// or when onFlush panics; a panic is reported as an E.PanicRecv error.
//
// flushTask.Finish must be called after the flush is done,
// but onFlush itself may return earlier (e.g. when flushing in another goroutine).
//
// If the task is canceled before the next flush, the events still in the queue are discarded.
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
return &EventQueue{
task: parent.Subtask("event queue"),
queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval),
onFlush: onFlush,
onError: onError,
}
}
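// Illustrative sketch (not part of this commit): batching watcher events.
// providerTask, eventCh, errCh and handle are assumptions supplied by the
// caller; only the API defined in this file is used otherwise.
//
//	queue := NewEventQueue(providerTask, 500*time.Millisecond,
//		func(flushTask task.Task, batch []Event) {
//			handle(batch)
//			flushTask.Finish("events flushed") // required once the flush is done
//		},
//		func(err E.NestedError) { logrus.Error(err) },
//	)
//	queue.Start(eventCh, errCh)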
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) {
go func() {
defer e.ticker.Stop()
for {
select {
case <-e.task.Context().Done():
e.task.Finish(e.task.FinishCause().Error())
return
case <-e.ticker.C:
if len(e.queue) > 0 {
flushTask := e.task.Subtask("flush events")
queue := e.queue
e.queue = make([]Event, 0, eventQueueCapacity)
go func() {
defer func() {
if err := recover(); err != nil {
e.onError(E.PanicRecv("panic in onFlush %s", err))
}
}()
e.onFlush(flushTask, queue)
}()
flushTask.Wait()
}
case event, ok := <-eventCh:
if !ok {
return
}
e.queue = append(e.queue, event)
case err := <-errCh:
if err != nil {
e.onError(err)
}
}
}
}()
}
// Wait waits for all events to be flushed and the task to finish.
//
// It is safe to call this method multiple times.
func (e *EventQueue) Wait() {
e.task.Wait()
}

View file

@ -74,7 +74,7 @@ var actionNameMap = func() (m map[Action]string) {
}()
func (e Event) String() string {
return fmt.Sprintf("%s %s", e.ActorName, e.Action)
return fmt.Sprintf("%s %s", e.Action, e.ActorName)
}
func (a Action) String() string {

View file

@ -5,7 +5,6 @@ import (
"errors"
"net/http"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types"
)
@ -15,10 +14,10 @@ type HTTPHealthMonitor struct {
pinger *http.Client
}
func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor {
func NewHTTPHealthMonitor(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) *HTTPHealthMonitor {
mon := new(HTTPHealthMonitor)
mon.monitor = newMonitor(task, url, config, mon.checkHealth)
mon.pinger = &http.Client{Timeout: config.Timeout}
mon.monitor = newMonitor(url, config, mon.CheckHealth)
mon.pinger = &http.Client{Timeout: config.Timeout, Transport: transport}
if config.UseGet {
mon.method = http.MethodGet
} else {
@ -27,19 +26,26 @@ func NewHTTPHealthMonitor(task common.Task, url types.URL, config *HealthCheckCo
return mon
}
func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) {
func NewHTTPHealthChecker(url types.URL, config *HealthCheckConfig, transport http.RoundTripper) HealthChecker {
return NewHTTPHealthMonitor(url, config, transport)
}
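// Illustrative sketch (not part of this commit): wiring an HTTP health
// monitor to a route's task; targetURL, cfg, and routeTask are assumptions.
//
//	mon := NewHTTPHealthMonitor(targetURL, cfg, nil) // nil transport uses http.DefaultTransport
//	if err := mon.Start(routeTask.Subtask("health monitor")); err != nil {
//		logrus.Error(err)
//	}
//	// ... on route teardown:
//	mon.Finish("route stopped")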
func (mon *HTTPHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
defer cancel()
req, reqErr := http.NewRequestWithContext(
mon.task.Context(),
ctx,
mon.method,
mon.url.JoinPath(mon.config.Path).String(),
mon.url.Load().JoinPath(mon.config.Path).String(),
nil,
)
if reqErr != nil {
err = reqErr
return
}
req.Header.Set("Connection", "close")
resp, respErr := mon.pinger.Do(req)
if respErr == nil {
resp.Body.Close()

View file

@ -2,78 +2,93 @@ package health
import (
"context"
"encoding/json"
"errors"
"sync"
"fmt"
"time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
HealthMonitor interface {
Start()
Stop()
task.TaskStarter
task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status
Uptime() time.Duration
Name() string
String() string
MarshalJSON() ([]byte, error)
}
HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
}
HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct {
service string
config *HealthCheckConfig
url types.URL
url U.AtomicValue[types.URL]
status U.AtomicValue[Status]
checkHealth HealthCheckFunc
startTime time.Time
task common.Task
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
task task.Task
}
)
var monMap = F.NewMapOf[string, HealthMonitor]()
func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
service := task.Name()
task, cancel := task.SubtaskWithCancel("Health monitor for %s", service)
func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
mon := &monitor{
service: service,
config: config,
url: url,
checkHealth: healthCheckFunc,
startTime: time.Now(),
task: task,
cancel: cancel,
done: make(chan struct{}),
task: task.DummyTask(),
}
mon.url.Store(url)
mon.status.Store(StatusHealthy)
return mon
}
func Inspect(name string) (HealthMonitor, bool) {
return monMap.Load(name)
func Inspect(service string) (HealthMonitor, bool) {
return monMap.Load(service)
}
func (mon *monitor) Start() {
defer monMap.Store(mon.task.Name(), mon)
func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cancel context.CancelFunc) {
if mon.task != nil {
return context.WithTimeoutCause(mon.task.Context(), mon.config.Timeout, errors.New(cause))
}
return context.WithTimeoutCause(context.Background(), mon.config.Timeout, errors.New(cause))
}
// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask
if err := mon.checkUpdateHealth(); err != nil {
mon.task.Finish(fmt.Sprintf("healthchecker %s failure: %s", mon.service, err))
return err
}
go func() {
defer close(mon.done)
defer mon.task.Finished()
defer func() {
monMap.Delete(mon.task.Name())
if mon.status.Load() != StatusError {
mon.status.Store(StatusUnknown)
}
mon.task.Finish(mon.task.FinishCause().Error())
}()
ok := mon.checkUpdateHealth()
if !ok {
return
}
monMap.Store(mon.service, mon)
ticker := time.NewTicker(mon.config.Interval)
defer ticker.Stop()
@ -83,48 +98,61 @@ func (mon *monitor) Start() {
case <-mon.task.Context().Done():
return
case <-ticker.C:
ok = mon.checkUpdateHealth()
if !ok {
err := mon.checkUpdateHealth()
if err != nil {
logger.Errorf("healthchecker %s failure: %s", mon.service, err)
return
}
}
}
}()
return nil
}
func (mon *monitor) Stop() {
monMap.Delete(mon.task.Name())
mon.mu.Lock()
defer mon.mu.Unlock()
if mon.cancel == nil {
return
}
mon.cancel()
<-mon.done
mon.cancel = nil
mon.status.Store(StatusUnknown)
// Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason string) {
mon.task.Finish(reason)
}
// UpdateURL implements HealthChecker.
func (mon *monitor) UpdateURL(url types.URL) {
mon.url.Store(url)
}
// URL implements HealthChecker.
func (mon *monitor) URL() types.URL {
return mon.url.Load()
}
// Config implements HealthChecker.
func (mon *monitor) Config() *HealthCheckConfig {
return mon.config
}
// Status implements HealthMonitor.
func (mon *monitor) Status() Status {
return mon.status.Load()
}
// Uptime implements HealthMonitor.
func (mon *monitor) Uptime() time.Duration {
return time.Since(mon.startTime)
}
// Name implements HealthMonitor.
func (mon *monitor) Name() string {
if mon.task == nil {
return ""
}
return mon.task.Name()
}
// String implements fmt.Stringer of HealthMonitor.
func (mon *monitor) String() string {
return mon.Name()
}
// MarshalJSON implements json.Marshaler of HealthMonitor.
func (mon *monitor) MarshalJSON() ([]byte, error) {
return (&JSONRepresentation{
Name: mon.service,
@ -132,19 +160,19 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
Status: mon.status.Load(),
Started: mon.startTime,
Uptime: mon.Uptime(),
URL: mon.url,
URL: mon.url.Load(),
}).MarshalJSON()
}
func (mon *monitor) checkUpdateHealth() (hasError bool) {
func (mon *monitor) checkUpdateHealth() E.NestedError {
healthy, detail, err := mon.checkHealth()
if err != nil {
defer mon.task.Finish(err.Error())
mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) {
logger.Errorf("%s failed to check health: %s", mon.service, err)
return E.Failure("check health").With(err)
}
mon.Stop()
return false
return nil
}
var status Status
if healthy {
@ -160,5 +188,5 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
}
}
return true
return nil
}

View file

@ -3,7 +3,6 @@ package health
import (
"net"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types"
)
@ -14,9 +13,9 @@ type (
}
)
func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckConfig) HealthMonitor {
func NewRawHealthMonitor(url types.URL, config *HealthCheckConfig) *RawHealthMonitor {
mon := new(RawHealthMonitor)
mon.monitor = newMonitor(task, url, config, mon.checkAvail)
mon.monitor = newMonitor(url, config, mon.CheckHealth)
mon.dialer = &net.Dialer{
Timeout: config.Timeout,
FallbackDelay: -1,
@ -24,14 +23,22 @@ func NewRawHealthMonitor(task common.Task, url types.URL, config *HealthCheckCon
return mon
}
func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) {
conn, dialErr := mon.dialer.DialContext(mon.task.Context(), mon.url.Scheme, mon.url.Host)
func NewRawHealthChecker(url types.URL, config *HealthCheckConfig) HealthChecker {
return NewRawHealthMonitor(url, config)
}
func (mon *RawHealthMonitor) CheckHealth() (healthy bool, detail string, err error) {
ctx, cancel := mon.ContextWithTimeout("ping request timed out")
defer cancel()
url := mon.url.Load()
conn, dialErr := mon.dialer.DialContext(ctx, url.Scheme, url.Host)
if dialErr != nil {
detail = dialErr.Error()
/* trunk-ignore(golangci-lint/nilerr) */
return
}
conn.Close()
avail = true
healthy = true
return
}