Improved healthcheck, idlewatcher support for loadbalanced routes, bug fixes

This commit is contained in:
yusing 2024-10-15 15:34:27 +08:00
parent 53fa28ae77
commit f4d532598c
34 changed files with 568 additions and 423 deletions

2
.gitignore vendored
View file

@ -1,6 +1,8 @@
compose.yml compose.yml
*.compose.yml *.compose.yml
config
certs
config*/ config*/
certs*/ certs*/
bin/ bin/

8
go.mod
View file

@ -41,12 +41,12 @@ require (
github.com/ovh/go-ovh v1.6.0 // indirect github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.30.0 // indirect go.opentelemetry.io/otel/sdk v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/crypto v0.28.0 // indirect golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.21.0 // indirect golang.org/x/mod v0.21.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect

16
go.sum
View file

@ -96,20 +96,20 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts= go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc= go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w= go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ= go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE= go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE=
go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg= go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc= go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o= go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View file

@ -15,10 +15,15 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
return return
} }
status, ok := health.Inspect(target) result, ok := health.Inspect(target)
if !ok { if !ok {
HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound) HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound)
return return
} }
WriteBody(w, []byte(status.String())) json, err := result.MarshalJSON()
if err != nil {
HandleErr(w, r, err)
return
}
RespondJSON(w, r, json)
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
) )
@ -45,16 +46,7 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
} }
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) { func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias() routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type")))
typeFilter := r.FormValue("type")
if typeFilter != "" {
for k, v := range routes {
if v["type"] != typeFilter {
delete(routes, k)
}
}
}
U.RespondJSON(w, r, routes) U.RespondJSON(w, r, routes)
} }

View file

@ -212,9 +212,9 @@ func GlobalContextWait(timeout time.Duration) {
case <-done: case <-done:
return return
case <-after: case <-after:
logrus.Println("Timeout waiting for these tasks to finish:") logrus.Warnln("Timeout waiting for these tasks to finish:")
tasksMap.Range(func(t *task, _ struct{}) bool { tasksMap.Range(func(t *task, _ struct{}) bool {
logrus.Println(t.tree()) logrus.Warnln(t.tree())
return true return true
}) })
return return

View file

@ -87,9 +87,8 @@ func (cfg *Config) StartProxyProviders() {
func (cfg *Config) WatchChanges() { func (cfg *Config) WatchChanges() {
task := common.NewTask("Config watcher") task := common.NewTask("Config watcher")
defer task.Finished()
go func() { go func() {
defer task.Finished()
for { for {
select { select {
case <-task.Context().Done(): case <-task.Context().Done():

View file

@ -42,25 +42,18 @@ func (cfg *Config) HomepageConfig() homepage.Config {
} }
hpCfg := homepage.NewHomePageConfig() hpCfg := homepage.NewHomePageConfig()
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
if !r.Started() { entry := r.Raw
return
}
entry := r.Entry
if entry.Homepage == nil {
entry.Homepage = &homepage.Item{
Show: r.Entry.IsExplicit || !p.IsExplicitOnly(),
}
}
item := entry.Homepage item := entry.Homepage
if item == nil {
item = new(homepage.Item)
}
if !item.Show && !item.IsEmpty() { if !item.Show && item.IsEmpty() {
item.Show = true item.Show = true
} }
if !item.Show || r.Type != R.RouteTypeReverseProxy { if !item.Show {
return return
} }
@ -73,12 +66,17 @@ func (cfg *Config) HomepageConfig() homepage.Config {
) )
} }
if p.GetType() == PR.ProviderTypeDocker { if r.IsDocker() {
if item.Category == "" { if item.Category == "" {
item.Category = "Docker" item.Category = "Docker"
} }
item.SourceType = string(PR.ProviderTypeDocker) item.SourceType = string(PR.ProviderTypeDocker)
} else if p.GetType() == PR.ProviderTypeFile { } else if r.UseLoadBalance() {
if item.Category == "" {
item.Category = "Load-balanced"
}
item.SourceType = "loadbalancer"
} else {
if item.Category == "" { if item.Category == "" {
item.Category = "Others" item.Category = "Others"
} }
@ -97,28 +95,20 @@ func (cfg *Config) HomepageConfig() homepage.Config {
return hpCfg return hpCfg
} }
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]U.SerializedObject { func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
routes := make(map[string]U.SerializedObject) routes := make(map[string]any)
if len(typeFilter) == 0 { if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream} typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream}
} }
for _, t := range typeFilter { for _, t := range typeFilter {
switch t { switch t {
case R.RouteTypeReverseProxy: case R.RouteTypeReverseProxy:
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) { R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
obj, err := U.Serialize(r) routes[alias] = r
if err != nil {
panic(err) // should not happen
}
routes[alias] = obj
}) })
case R.RouteTypeStream: case R.RouteTypeStream:
R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) { R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) {
obj, err := U.Serialize(r) routes[alias] = r
if err != nil {
panic(err) // should not happen
}
routes[alias] = obj
}) })
} }
} }

View file

@ -43,7 +43,9 @@ func init() {
select { select {
case <-task.Context().Done(): case <-task.Context().Done():
clientMap.RangeAllParallel(func(_ string, c Client) { clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close() if c.Connected() {
c.Client.Close()
}
}) })
clientMap.Clear() clientMap.Clear()
return return

View file

@ -41,6 +41,8 @@ type (
} }
) )
var DummyContainer = new(Container)
func FromDocker(c *types.Container, dockerHost string) (res *Container) { func FromDocker(c *types.Container, dockerHost string) (res *Container) {
isExplicit := c.Labels[LabelAliases] != "" isExplicit := c.Labels[LabelAliases] != ""
helper := containerHelper{c} helper := containerHelper{c}

View file

@ -2,7 +2,6 @@ package idlewatcher
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -73,15 +72,15 @@ func (w *Waker) Uptime() time.Duration {
} }
func (w *Waker) MarshalJSON() ([]byte, error) { func (w *Waker) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{ return (&health.JSONRepresentation{
"name": w.Name(), Name: w.Name(),
"url": w.URL, Status: w.Status(),
"status": w.Status(), Config: &health.HealthCheckConfig{
"config": health.HealthCheckConfig{
Interval: w.IdleTimeout, Interval: w.IdleTimeout,
Timeout: w.WakeTimeout, Timeout: w.WakeTimeout,
}, },
}) URL: w.URL,
}).MarshalJSON()
} }
/* End of HealthMonitor interface */ /* End of HealthMonitor interface */
@ -89,6 +88,10 @@ func (w *Waker) MarshalJSON() ([]byte, error) {
func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer() w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is ready // pass through if container is ready
if w.ready.Load() { if w.ready.Load() {
return true return true
@ -115,6 +118,16 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool)
return return
} }
select {
case <-w.task.Context().Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
// wake the container and reset idle timer // wake the container and reset idle timer
// also wait for another wake request // also wait for another wake request
w.wakeCh <- struct{}{} w.wakeCh <- struct{}{}
@ -169,3 +182,8 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
} }
// static HealthMonitor interface check
func (w *Waker) _() health.HealthMonitor {
return w
}

View file

@ -8,10 +8,12 @@ import (
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy" P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields" PT "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
) )
@ -29,9 +31,10 @@ type (
wakeDone chan E.NestedError wakeDone chan E.NestedError
ticker *time.Ticker ticker *time.Ticker
ctx context.Context task common.Task
cancel context.CancelFunc cancel context.CancelFunc
refCount *sync.WaitGroup
refCount *U.RefCount
l logrus.FieldLogger l logrus.FieldLogger
} }
@ -42,17 +45,11 @@ type (
) )
var ( var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = F.NewMapOf[string, *Watcher]() watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex watcherMapMu sync.Mutex
portHistoryMap = F.NewMapOf[PT.Alias, string]() portHistoryMap = F.NewMapOf[PT.Alias, string]()
newWatcherCh = make(chan *Watcher)
logger = logrus.WithField("module", "idle_watcher") logger = logrus.WithField("module", "idle_watcher")
) )
@ -73,7 +70,7 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
} }
if w, ok := watcherMap.Load(key); ok { if w, ok := watcherMap.Load(key); ok {
w.refCount.Add(1) w.refCount.Add()
w.ReverseProxyEntry = entry w.ReverseProxyEntry = entry
return w, nil return w, nil
} }
@ -86,83 +83,51 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
w := &Watcher{ w := &Watcher{
ReverseProxyEntry: entry, ReverseProxyEntry: entry,
client: client, client: client,
refCount: &sync.WaitGroup{}, refCount: U.NewRefCounter(),
wakeCh: make(chan struct{}, 1), wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError), wakeDone: make(chan E.NestedError),
ticker: time.NewTicker(entry.IdleTimeout), ticker: time.NewTicker(entry.IdleTimeout),
l: logger.WithField("container", entry.ContainerName), l: logger.WithField("container", entry.ContainerName),
} }
w.refCount.Add(1) w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias)
w.stopByMethod = w.getStopCallback() w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w) watcherMap.Store(key, w)
go func() { go w.watchUntilCancel()
newWatcherCh <- w
}()
return w, nil return w, nil
} }
func (w *Watcher) Unregister() { func (w *Watcher) Unregister() {
w.refCount.Add(-1) w.refCount.Sub()
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watchUntilCancel()
w.refCount.Wait() // wait for 0 ref count
watcherMap.Delete(w.ContainerID)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
} }
func (w *Watcher) containerStop() error { func (w *Watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{ return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal), Signal: string(w.StopSignal),
Timeout: &w.StopTimeout, Timeout: &w.StopTimeout,
}) })
} }
func (w *Watcher) containerPause() error { func (w *Watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerID) return w.client.ContainerPause(w.task.Context(), w.ContainerID)
} }
func (w *Watcher) containerKill() error { func (w *Watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal)) return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal))
} }
func (w *Watcher) containerUnpause() error { func (w *Watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerID) return w.client.ContainerUnpause(w.task.Context(), w.ContainerID)
} }
func (w *Watcher) containerStart() error { func (w *Watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{}) return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{})
} }
func (w *Watcher) containerStatus() (string, E.NestedError) { func (w *Watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerID) json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID)
if err != nil { if err != nil {
return "", E.FailWith("inspect container", err) return "", E.FailWith("inspect container", err)
} }
@ -221,12 +186,8 @@ func (w *Watcher) resetIdleTimer() {
} }
func (w *Watcher) watchUntilCancel() { func (w *Watcher) watchUntilCancel() {
defer close(w.wakeCh)
w.ctx, w.cancel = context.WithCancel(mainLoopCtx)
dockerWatcher := W.NewDockerWatcherWithClient(w.client) dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{ dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter( Filters: W.NewDockerFilter(
W.DockerFilterContainer, W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID), W.DockerrFilterContainer(w.ContainerID),
@ -238,13 +199,23 @@ func (w *Watcher) watchUntilCancel() {
W.DockerFilterUnpause, W.DockerFilterUnpause,
), ),
}) })
defer w.ticker.Stop()
defer w.client.Close() defer func() {
w.ticker.Stop()
w.client.Close()
close(w.wakeDone)
close(w.wakeCh)
watcherMap.Delete(w.ContainerID)
w.task.Finished()
}()
for { for {
select { select {
case <-w.ctx.Done(): case <-w.task.Context().Done():
w.l.Debug("stopped") w.l.Debug("stopped by context done")
return
case <-w.refCount.Zero():
w.l.Debug("stopped by zero ref count")
return return
case err := <-dockerEventErrCh: case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) { if err != nil && err.IsNot(context.Canceled) {

View file

@ -7,6 +7,17 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) {
client, err := ConnectClient(dockerHost)
defer client.Close()
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
return client.Inspect(containerID)
}
func (c Client) Inspect(containerID string) (*Container, E.NestedError) { func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()

View file

@ -46,7 +46,18 @@ func TestBuilderNested(t *testing.T) {
- invalid Inner: "2" - invalid Inner: "2"
- Action 2 failed: - Action 2 failed:
- invalid Inner: "3"`) - invalid Inner: "3"`)
if got != expected1 && got != expected2 { ExpectEqualAny(t, got, []string{expected1, expected2})
t.Errorf("expected \n%s, got \n%s", expected1, got) }
}
func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Addf("abcd")
var err NestedError
eb.To(&err)
got := err.String()
expected := (`error occurred:
- abcd`)
ExpectEqual(t, got, expected)
} }

View file

@ -6,8 +6,8 @@ import (
"time" "time"
"github.com/go-acme/lego/v4/log" "github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/watcher/health"
) )
// TODO: stats of each server. // TODO: stats of each server.
@ -41,13 +41,14 @@ type (
const maxWeight weightType = 100 const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer { func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{Config: cfg, pool: servers{}} lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)}
mode := cfg.Mode lb.UpdateConfigIfNeeded(cfg)
if !cfg.Mode.ValidateUpdate() { return lb
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode) }
}
switch mode { func (lb *LoadBalancer) updateImpl() {
case RoundRobin: switch lb.Mode {
case Unset, RoundRobin:
lb.impl = lb.newRoundRobin() lb.impl = lb.newRoundRobin()
case LeastConn: case LeastConn:
lb.impl = lb.newLeastConn() lb.impl = lb.newLeastConn()
@ -56,7 +57,34 @@ func New(cfg *Config) *LoadBalancer {
default: // should happen in test only default: // should happen in test only
lb.impl = lb.newRoundRobin() lb.impl = lb.newRoundRobin()
} }
return lb for _, srv := range lb.pool {
lb.impl.OnAddServer(srv)
}
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if cfg != nil {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.Link = cfg.Link
if lb.Mode == Unset && cfg.Mode != Unset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
lb.Options = cfg.Options
}
}
if lb.impl == nil {
lb.updateImpl()
}
} }
func (lb *LoadBalancer) AddServer(srv *Server) { func (lb *LoadBalancer) AddServer(srv *Server) {
@ -66,6 +94,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.pool = append(lb.pool, srv) lb.pool = append(lb.pool, srv)
lb.sumWeight += srv.Weight lb.sumWeight += srv.Weight
lb.Rebalance()
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool)) logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
} }
@ -74,6 +103,8 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
lb.sumWeight -= srv.Weight
lb.Rebalance()
lb.impl.OnRemoveServer(srv) lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool { for i, s := range lb.pool {
@ -87,7 +118,6 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
return return
} }
lb.Rebalance()
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool)) logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
} }
@ -152,20 +182,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
} }
func (lb *LoadBalancer) Start() { func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
for _, s := range lb.pool {
msg.Addf("%s: %d", s.Name, s.Weight)
}
lb.Rebalance()
inner := E.NewBuilder("after rebalancing")
for _, s := range lb.pool {
inner.Addf("%s: %d", s.Name, s.Weight)
}
msg.Addf("%s", inner)
logger.Warn(msg)
}
if lb.sumWeight != 0 { if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet") log.Warnf("weighted mode not supported yet")
} }
@ -186,6 +202,45 @@ func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime) return time.Since(lb.startTime)
} }
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
for _, v := range lb.pool {
extra[v.Name] = v.healthMon
}
return (&health.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
Uptime: lb.Uptime(),
Extra: map[string]any{
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
return lb.Link
}
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if len(lb.pool) == 0 {
return health.StatusUnknown
}
if len(lb.availServers()) == 0 {
return health.StatusUnhealthy
}
return health.StatusHealthy
}
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() servers { func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock() lb.poolMu.Lock()
defer lb.poolMu.Unlock() defer lb.poolMu.Unlock()
@ -199,3 +254,8 @@ func (lb *LoadBalancer) availServers() servers {
} }
return avail return avail
} }
// static HealthMonitor interface check
func (lb *LoadBalancer) _() health.HealthMonitor {
return lb
}

View file

@ -7,6 +7,7 @@ import (
type Mode string type Mode string
const ( const (
Unset Mode = ""
RoundRobin Mode = "roundrobin" RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn" LeastConn Mode = "leastconn"
IPHash Mode = "iphash" IPHash Mode = "iphash"
@ -14,7 +15,9 @@ const (
func (mode *Mode) ValidateUpdate() bool { func (mode *Mode) ValidateUpdate() bool {
switch U.ToLowerNoSnake(string(*mode)) { switch U.ToLowerNoSnake(string(*mode)) {
case "", string(RoundRobin): case "":
return true
case string(RoundRobin):
*mode = RoundRobin *mode = RoundRobin
return true return true
case string(LeastConn): case string(LeastConn):

View file

@ -16,32 +16,36 @@ import (
type ( type (
ReverseProxyEntry struct { // real model after validation ReverseProxyEntry struct { // real model after validation
Alias T.Alias `json:"alias"` Raw *types.RawEntry `json:"raw"`
Scheme T.Scheme `json:"scheme"`
URL net.URL `json:"url"` Alias T.Alias `json:"alias,omitempty"`
Scheme T.Scheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"` NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns T.PathPatterns `json:"path_patterns"` PathPatterns T.PathPatterns `json:"path_patterns,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck"` HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"` LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty"` Middlewares D.NestedLabelMap `json:"middlewares,omitempty"`
/* Docker only */ /* Docker only */
IdleTimeout time.Duration `json:"idle_timeout"` IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout"` WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopMethod T.StopMethod `json:"stop_method"` StopMethod T.StopMethod `json:"stop_method,omitempty"`
StopTimeout int `json:"stop_timeout"` StopTimeout int `json:"stop_timeout,omitempty"`
StopSignal T.Signal `json:"stop_signal,omitempty"` StopSignal T.Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host"` DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name"` ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id"` ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running"` ContainerRunning bool `json:"container_running,omitempty"`
} }
StreamEntry struct { StreamEntry struct {
Alias T.Alias `json:"alias"` Raw *types.RawEntry `json:"raw"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"` Alias T.Alias `json:"alias,omitempty"`
Port T.StreamPort `json:"port"` Scheme T.StreamScheme `json:"scheme,omitempty"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck"` Host T.Host `json:"host,omitempty"`
Port T.StreamPort `json:"port,omitempty"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
} }
) )
@ -88,6 +92,10 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry { func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration var stopTimeOut time.Duration
cont := m.Container
if cont == nil {
cont = D.DummyContainer
}
host, err := T.ValidateHost(m.Host) host, err := T.ValidateHost(m.Host)
b.Add(err) b.Add(err)
@ -101,21 +109,21 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err) b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout) idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout)
b.Add(err) b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout) wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout)
b.Add(err) b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod) stopMethod, err := T.ValidateStopMethod(cont.StopMethod)
b.Add(err) b.Add(err)
if stopMethod == T.StopMethodStop { if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout) stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout)
b.Add(err) b.Add(err)
} }
stopSignal, err := T.ValidateSignal(m.StopSignal) stopSignal, err := T.ValidateSignal(cont.StopSignal)
b.Add(err) b.Add(err)
if err != nil { if err != nil {
@ -123,6 +131,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
} }
return &ReverseProxyEntry{ return &ReverseProxyEntry{
Raw: m,
Alias: T.NewAlias(m.Alias), Alias: T.NewAlias(m.Alias),
Scheme: s, Scheme: s,
URL: net.NewURL(url), URL: net.NewURL(url),
@ -136,10 +145,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
StopMethod: stopMethod, StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal, StopSignal: stopSignal,
DockerHost: m.DockerHost, DockerHost: cont.DockerHost,
ContainerName: m.ContainerName, ContainerName: cont.ContainerName,
ContainerID: m.ContainerID, ContainerID: cont.ContainerID,
ContainerRunning: m.Running, ContainerRunning: cont.Running,
} }
} }
@ -158,6 +167,7 @@ func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
} }
return &StreamEntry{ return &StreamEntry{
Raw: m,
Alias: T.NewAlias(m.Alias), Alias: T.NewAlias(m.Alias),
Scheme: *scheme, Scheme: *scheme,
Host: host, Host: host,

View file

@ -32,7 +32,7 @@ func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
} }
func (s StreamScheme) String() string { func (s StreamScheme) String() string {
return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme) return fmt.Sprintf("%s -> %s", s.ListeningScheme, s.ProxyScheme)
} }
// IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal. // IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal.

View file

@ -72,7 +72,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
} }
entries.RangeAll(func(_ string, e *types.RawEntry) { entries.RangeAll(func(_ string, e *types.RawEntry) {
e.DockerHost = p.dockerHost e.Container.DockerHost = p.dockerHost
}) })
routes, err = R.FromEntries(entries) routes, err = R.FromEntries(entries)
@ -88,7 +88,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
strings.HasSuffix(container.ContainerName, "-old") strings.HasSuffix(container.ContainerName, "-old")
} }
func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) {
switch event.Action { switch event.Action {
case events.ActionContainerStart, events.ActionContainerStop: case events.ActionContainerStart, events.ActionContainerStop:
break break
@ -98,75 +98,66 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
b := E.NewBuilder("event %s error", event) b := E.NewBuilder("event %s error", event)
defer b.To(&res.err) defer b.To(&res.err)
routes.RangeAll(func(k string, v *R.Route) { matches := R.NewRoutes()
if v.Entry.ContainerID == event.ActorID || oldRoutes.RangeAllParallel(func(k string, v *R.Route) {
v.Entry.ContainerName == event.ActorName { if v.Entry.Container.ContainerID == event.ActorID ||
v.Entry.Container.ContainerName == event.ActorName {
matches.Store(k, v)
}
})
var newRoutes R.Routes
var err E.NestedError
if matches.Size() == 0 { // id & container name changed
matches = oldRoutes
newRoutes, err = p.LoadRoutesImpl()
b.Add(err)
} else {
cont, err := D.Inspect(p.dockerHost, event.ActorID)
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
if p.shouldIgnore(cont) {
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
})
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
newRoutes, err = R.FromEntries(entries)
b.Add(err)
}
matches.RangeAll(func(k string, v *R.Route) {
if !newRoutes.Has(k) && !oldRoutes.Has(k) {
b.Add(v.Stop()) b.Add(v.Stop())
routes.Delete(k) matches.Delete(k)
res.nRemoved++ res.nRemoved++
} }
}) })
if res.nRemoved == 0 { // id & container name changed newRoutes.RangeAll(func(alias string, newRoute *R.Route) {
// load all routes (rescan) oldRoute, exists := oldRoutes.Load(alias)
routesNew, err := p.LoadRoutesImpl() if exists {
routesOld := routes if err := oldRoute.Stop(); err != nil {
if routesNew.Size() == 0 {
b.Add(E.FailWith("rescan routes", err))
return
}
routesNew.Range(func(k string, v *R.Route) bool {
if !routesOld.Has(k) {
routesOld.Store(k, v)
b.Add(v.Start())
res.nAdded++
return false
}
return true
})
routesOld.Range(func(k string, v *R.Route) bool {
if !routesNew.Has(k) {
b.Add(v.Stop())
routesOld.Delete(k)
res.nRemoved++
return false
}
return true
})
return
}
client, err := D.ConnectClient(p.dockerHost)
if err != nil {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
if p.shouldIgnore(cont) {
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *types.RawEntry) {
if routes.Has(alias) {
b.Add(E.Duplicated("alias", alias))
} else {
if route, err := R.NewRoute(entry); err != nil {
b.Add(err) b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
} }
} }
oldRoutes.Store(alias, newRoute)
if err := newRoute.Start(); err != nil {
b.Add(err)
}
if exists {
res.nReloaded++
} else {
res.nAdded++
}
}) })
return return

View file

@ -88,20 +88,20 @@ func TestApplyLabelWildcard(t *testing.T) {
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect) ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0) ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault) ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault) ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault) ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault) ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.StopMethod, common.StopMethodDefault) ExpectEqual(t, a.Container.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.StopMethod, common.StopMethodDefault) ExpectEqual(t, b.Container.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault) ExpectEqual(t, a.Container.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault) ExpectEqual(t, b.Container.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.StopSignal, "SIGTERM") ExpectEqual(t, a.Container.StopSignal, "SIGTERM")
ExpectEqual(t, b.StopSignal, "SIGTERM") ExpectEqual(t, b.Container.StopSignal, "SIGTERM")
} }
func TestApplyLabelWithAlias(t *testing.T) { func TestApplyLabelWithAlias(t *testing.T) {
@ -186,16 +186,16 @@ func TestPublicIPLocalhost(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost) c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.PublicIP, "127.0.0.1") ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1")
ExpectEqual(t, raw.Host, raw.PublicIP) ExpectEqual(t, raw.Host, raw.Container.PublicIP)
} }
func TestPublicIPRemote(t *testing.T) { func TestPublicIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375") c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.PublicIP, "1.2.3.4") ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.PublicIP) ExpectEqual(t, raw.Host, raw.Container.PublicIP)
} }
func TestPrivateIPLocalhost(t *testing.T) { func TestPrivateIPLocalhost(t *testing.T) {
@ -211,8 +211,8 @@ func TestPrivateIPLocalhost(t *testing.T) {
}, client.DefaultDockerHost) }, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.PrivateIP, "172.17.0.123") ExpectEqual(t, raw.Container.PrivateIP, "172.17.0.123")
ExpectEqual(t, raw.Host, raw.PrivateIP) ExpectEqual(t, raw.Host, raw.Container.PrivateIP)
} }
func TestPrivateIPRemote(t *testing.T) { func TestPrivateIPRemote(t *testing.T) {
@ -228,9 +228,9 @@ func TestPrivateIPRemote(t *testing.T) {
}, "tcp://1.2.3.4:2375") }, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectEqual(t, raw.PrivateIP, "") ExpectEqual(t, raw.Container.PrivateIP, "")
ExpectEqual(t, raw.PublicIP, "1.2.3.4") ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.PublicIP) ExpectEqual(t, raw.Host, raw.Container.PublicIP)
} }
func TestStreamDefaultValues(t *testing.T) { func TestStreamDefaultValues(t *testing.T) {

View file

@ -5,6 +5,7 @@ import (
"path" "path"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
@ -19,7 +20,7 @@ type (
routes R.Routes routes R.Routes
watcher W.Watcher watcher W.Watcher
watcherCtx context.Context watcherTask common.Task
watcherCancel context.CancelFunc watcherCancel context.CancelFunc
l *logrus.Entry l *logrus.Entry
@ -38,9 +39,10 @@ type (
Type ProviderType `json:"type"` Type ProviderType `json:"type"`
} }
EventResult struct { EventResult struct {
nRemoved int nAdded int
nAdded int nRemoved int
err E.NestedError nReloaded int
err E.NestedError
} }
) )
@ -129,6 +131,7 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) {
p.routes.RangeAllParallel(func(alias string, r *R.Route) { p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Stop().Subject(r)) errors.Add(r.Stop().Subject(r))
}) })
p.routes.Clear()
return return
} }
@ -175,27 +178,21 @@ func (p *Provider) Statistics() ProviderStats {
} }
func (p *Provider) watchEvents() { func (p *Provider) watchEvents() {
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name)
events, errs := p.watcher.Events(p.watcherCtx) defer p.watcherTask.Finished()
events, errs := p.watcher.Events(p.watcherTask.Context())
l := p.l.WithField("module", "watcher") l := p.l.WithField("module", "watcher")
for { for {
select { select {
case <-p.watcherCtx.Done(): case <-p.watcherTask.Context().Done():
return return
case event := <-events: case event := <-events:
res := p.OnEvent(event, p.routes) res := p.OnEvent(event, p.routes)
l.Infof("%s event %q", event.Type, event) if res.nAdded+res.nRemoved+res.nReloaded > 0 {
if res.nAdded > 0 || res.nRemoved > 0 { l.Infof("%s event %q", event.Type, event)
n := res.nAdded - res.nRemoved l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded)
switch {
case n == 0:
l.Infof("%d route(s) reloaded", res.nAdded)
case n > 0:
l.Infof("%d route(s) added", n)
default:
l.Infof("%d route(s) removed", -n)
}
} }
if res.err != nil { if res.err != nil {
l.Error(res.err) l.Error(res.err)

View file

@ -18,6 +18,7 @@ import (
url "github.com/yusing/go-proxy/internal/net/types" url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy" P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields" PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -26,12 +27,12 @@ type (
HTTPRoute struct { HTTPRoute struct {
*P.ReverseProxyEntry `json:"entry"` *P.ReverseProxyEntry `json:"entry"`
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"` HealthMon health.HealthMonitor `json:"health,omitempty"`
HealthMon health.HealthMonitor `json:"health"`
server *loadbalancer.Server loadBalancer *loadbalancer.LoadBalancer
handler http.Handler server *loadbalancer.Server
rp *gphttp.ReverseProxy handler http.Handler
rp *gphttp.ReverseProxy
} }
SubdomainKey = PT.Alias SubdomainKey = PT.Alias
@ -102,10 +103,6 @@ func (r *HTTPRoute) URL() url.URL {
} }
func (r *HTTPRoute) Start() E.NestedError { func (r *HTTPRoute) Start() E.NestedError {
if r.handler != nil {
return nil
}
if r.ShouldNotServe() { if r.ShouldNotServe() {
return nil return nil
} }
@ -113,6 +110,10 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock() httpRoutesMu.Lock()
defer httpRoutesMu.Unlock() defer httpRoutesMu.Unlock()
if r.handler != nil {
return nil
}
if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) { if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = true r.HealthCheck.Disabled = true
@ -129,15 +130,23 @@ func (r *HTTPRoute) Start() E.NestedError {
r.HealthMon = waker r.HealthMon = waker
case !r.HealthCheck.Disabled: case !r.HealthCheck.Disabled:
r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask("Reverse proxy "+r.String()), r.URL(), r.HealthCheck) r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask("Reverse proxy "+r.String()), r.URL(), r.HealthCheck)
fallthrough }
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
r.handler = ReverseProxyHandler{r.rp} if r.handler == nil {
default: switch {
mux := http.NewServeMux() case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
for _, p := range r.PathPatterns { r.handler = ReverseProxyHandler{r.rp}
mux.HandleFunc(string(p), r.rp.ServeHTTP) default:
mux := http.NewServeMux()
for _, p := range r.PathPatterns {
mux.HandleFunc(string(p), r.rp.ServeHTTP)
}
r.handler = mux
} }
r.handler = mux }
if r.HealthMon != nil {
r.HealthMon.Start()
} }
if r.UseLoadBalance() { if r.UseLoadBalance() {
@ -146,9 +155,6 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutes.Store(string(r.Alias), r) httpRoutes.Store(string(r.Alias), r)
} }
if r.HealthMon != nil {
r.HealthMon.Start()
}
return nil return nil
} }
@ -160,7 +166,7 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
httpRoutesMu.Lock() httpRoutesMu.Lock()
defer httpRoutesMu.Unlock() defer httpRoutesMu.Unlock()
if r.LoadBalancer != nil { if r.loadBalancer != nil {
r.removeFromLoadBalancer() r.removeFromLoadBalancer()
} else { } else {
httpRoutes.Delete(string(r.Alias)) httpRoutes.Delete(string(r.Alias))
@ -184,29 +190,40 @@ func (r *HTTPRoute) addToLoadBalancer() {
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(r.LoadBalance.Link) linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok { if ok {
lb = linked.LoadBalancer lb = linked.loadBalancer
lb.UpdateConfigIfNeeded(r.LoadBalance)
if linked.Raw.Homepage == nil && r.Raw.Homepage != nil {
linked.Raw.Homepage = r.Raw.Homepage
}
} else { } else {
lb = loadbalancer.New(r.LoadBalance) lb = loadbalancer.New(r.LoadBalance)
lb.Start() lb.Start()
linked = &HTTPRoute{ linked = &HTTPRoute{
LoadBalancer: lb, ReverseProxyEntry: &P.ReverseProxyEntry{
Raw: &types.RawEntry{
Homepage: r.Raw.Homepage,
},
Alias: PT.Alias(lb.Link),
},
HealthMon: lb,
loadBalancer: lb,
handler: lb, handler: lb,
} }
httpRoutes.Store(r.LoadBalance.Link, linked) httpRoutes.Store(r.LoadBalance.Link, linked)
} }
r.LoadBalancer = lb r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon) r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server) lb.AddServer(r.server)
} }
func (r *HTTPRoute) removeFromLoadBalancer() { func (r *HTTPRoute) removeFromLoadBalancer() {
r.LoadBalancer.RemoveServer(r.server) r.loadBalancer.RemoveServer(r.server)
if r.LoadBalancer.IsEmpty() { if r.loadBalancer.IsEmpty() {
httpRoutes.Delete(r.LoadBalance.Link) httpRoutes.Delete(r.LoadBalance.Link)
logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link) logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link)
} }
r.server = nil r.server = nil
r.LoadBalancer = nil r.loadBalancer = nil
} }
func ProxyHandler(w http.ResponseWriter, r *http.Request) { func ProxyHandler(w http.ResponseWriter, r *http.Request) {

View file

@ -1,6 +1,7 @@
package route package route
import ( import (
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types" url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy" P "github.com/yusing/go-proxy/internal/proxy"
@ -36,6 +37,13 @@ const (
// function alias. // function alias.
var NewRoutes = F.NewMap[Routes] var NewRoutes = F.NewMap[Routes]
func (rt *Route) Container() *docker.Container {
if rt.Entry.Container == nil {
return docker.DummyContainer
}
return rt.Entry.Container
}
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
entry, err := P.ValidateEntry(en) entry, err := P.ValidateEntry(en)
if err != nil { if err != nil {

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -24,12 +25,11 @@ type StreamRoute struct {
url url.URL url url.URL
wg sync.WaitGroup
task common.Task task common.Task
cancel context.CancelFunc cancel context.CancelFunc
done chan struct{}
connCh chan any l logrus.FieldLogger
l logrus.FieldLogger
mu sync.Mutex mu sync.Mutex
} }
@ -61,7 +61,6 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
base := &StreamRoute{ base := &StreamRoute{
StreamEntry: entry, StreamEntry: entry,
url: url, url: url,
connCh: make(chan any, 100),
} }
if entry.Scheme.ListeningScheme.IsTCP() { if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base) base.StreamImpl = NewTCPRoute(base)
@ -73,7 +72,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
} }
func (r *StreamRoute) String() string { func (r *StreamRoute) String() string {
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias) return fmt.Sprintf("stream %s", r.Alias)
} }
func (r *StreamRoute) URL() url.URL { func (r *StreamRoute) URL() url.URL {
@ -88,14 +87,12 @@ func (r *StreamRoute) Start() E.NestedError {
return nil return nil
} }
r.task, r.cancel = common.NewTaskWithCancel(r.String()) r.task, r.cancel = common.NewTaskWithCancel(r.String())
r.wg.Wait()
if err := r.Setup(); err != nil { if err := r.Setup(); err != nil {
return E.FailWith("setup", err) return E.FailWith("setup", err)
} }
r.done = make(chan struct{})
r.l.Infof("listening on port %d", r.Port.ListeningPort) r.l.Infof("listening on port %d", r.Port.ListeningPort)
r.wg.Add(2)
go r.acceptConnections() go r.acceptConnections()
go r.handleConnections()
if !r.Healthcheck.Disabled { if !r.Healthcheck.Disabled {
r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck) r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck)
r.HealthMon.Start() r.HealthMon.Start()
@ -122,11 +119,7 @@ func (r *StreamRoute) Stop() E.NestedError {
r.cancel() r.cancel()
r.CloseListeners() r.CloseListeners()
r.wg.Wait() <-r.done
r.task.Finished()
r.task, r.cancel = nil, nil
return nil return nil
} }
@ -135,41 +128,45 @@ func (r *StreamRoute) Started() bool {
} }
func (r *StreamRoute) acceptConnections() { func (r *StreamRoute) acceptConnections() {
defer r.wg.Done() var connWg sync.WaitGroup
task := r.task.Subtask("%s accept connections", r.String())
defer func() {
connWg.Wait()
task.Finished()
r.task.Finished()
r.task, r.cancel = nil, nil
close(r.done)
r.done = nil
}()
for { for {
select { select {
case <-r.task.Context().Done(): case <-task.Context().Done():
return return
default: default:
conn, err := r.Accept() conn, err := r.Accept()
if err != nil { if err != nil {
select { select {
case <-r.task.Context().Done(): case <-task.Context().Done():
return return
default: default:
r.l.Error(err) var nErr *net.OpError
ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) {
r.l.Error(err)
}
continue continue
} }
} }
r.connCh <- conn connWg.Add(1)
}
}
}
func (r *StreamRoute) handleConnections() {
defer r.wg.Done()
for {
select {
case <-r.task.Context().Done():
return
case conn := <-r.connCh:
go func() { go func() {
err := r.Handle(conn) err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) { if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err) r.l.Error(err)
} }
connWg.Done()
}() }()
} }
} }

View file

@ -4,31 +4,25 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
T "github.com/yusing/go-proxy/internal/proxy/fields" T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
) )
const tcpDialTimeout = 5 * time.Second const tcpDialTimeout = 5 * time.Second
type ( type (
Pipes []U.BidirectionalPipe TCPConnMap = F.Map[net.Conn, struct{}]
TCPRoute struct {
TCPRoute struct {
*StreamRoute *StreamRoute
listener net.Listener listener *net.TCPListener
pipe Pipes
mu sync.Mutex
} }
) )
func NewTCPRoute(base *StreamRoute) StreamImpl { func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{ return &TCPRoute{StreamRoute: base}
StreamRoute: base,
pipe: make(Pipes, 0),
}
} }
func (route *TCPRoute) Setup() error { func (route *TCPRoute) Setup() error {
@ -38,11 +32,12 @@ func (route *TCPRoute) Setup() error {
} }
//! this read the allocated port from original ':0' //! this read the allocated port from original ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in route.listener = in.(*net.TCPListener)
return nil return nil
} }
func (route *TCPRoute) Accept() (any, error) { func (route *TCPRoute) Accept() (any, error) {
route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept() return route.listener.Accept()
} }
@ -50,24 +45,23 @@ func (route *TCPRoute) Handle(c any) error {
clientConn := c.(net.Conn) clientConn := c.(net.Conn)
defer clientConn.Close() defer clientConn.Close()
go func() {
<-route.task.Context().Done()
clientConn.Close()
}()
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
defer cancel()
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{} dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
cancel()
if err != nil { if err != nil {
return err return err
} }
route.mu.Lock()
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start() return pipe.Start()
} }

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields" T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
@ -67,6 +68,7 @@ func (route *UDPRoute) Accept() (any, error) {
in := route.listeningConn in := route.listeningConn
buffer := make([]byte, udpBufferSize) buffer := make([]byte, udpBufferSize)
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
nRead, srcAddr, err := in.ReadFromUDP(buffer) nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -84,7 +84,7 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider, CertProvider: opt.CertProvider,
http: httpSer, http: httpSer,
https: httpsSer, https: httpsSer,
task: common.GlobalTask("Server " + opt.Name), task: common.GlobalTask(opt.Name + " server"),
} }
} }
@ -133,13 +133,11 @@ func (s *Server) stop() {
if s.http != nil && s.httpStarted { if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(ctx)) s.handleErr("http", s.http.Shutdown(ctx))
s.httpStarted = false s.httpStarted = false
logger.Debugf("HTTP server %q stopped", s.Name)
} }
if s.https != nil && s.httpsStarted { if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(ctx)) s.handleErr("https", s.https.Shutdown(ctx))
s.httpsStarted = false s.httpsStarted = false
logger.Debugf("HTTPS server %q stopped", s.Name)
} }
} }
@ -152,7 +150,7 @@ func (s *Server) handleErr(scheme string, err error) {
case err == nil, errors.Is(err, http.ErrServerClosed): case err == nil, errors.Is(err, http.ErrServerClosed):
return return
default: default:
logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err) logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err)
} }
} }

View file

@ -22,9 +22,9 @@ type (
// raw entry object before validation // raw entry object before validation
// loaded from docker labels or yaml file // loaded from docker labels or yaml file
Alias string `json:"-" yaml:"-"` Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme" yaml:"scheme"` Scheme string `json:"scheme,omitempty" yaml:"scheme"`
Host string `json:"host" yaml:"host"` Host string `json:"host,omitempty" yaml:"host"`
Port string `json:"port" yaml:"port"` Port string `json:"port,omitempty" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
@ -33,7 +33,7 @@ type (
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */ /* Docker only */
*docker.Container `json:"container" yaml:"-"` Container *docker.Container `json:"container,omitempty" yaml:"-"`
} }
RawEntries = F.Map[string, *RawEntry] RawEntries = F.Map[string, *RawEntry]
@ -43,16 +43,17 @@ var NewProxyEntries = F.NewMapOf[string, *RawEntry]
func (e *RawEntry) FillMissingFields() { func (e *RawEntry) FillMissingFields() {
isDocker := e.Container != nil isDocker := e.Container != nil
cont := e.Container
if !isDocker { if !isDocker {
e.Container = &docker.Container{} cont = docker.DummyContainer
} }
if e.Host == "" { if e.Host == "" {
switch { switch {
case e.PrivateIP != "": case cont.PrivateIP != "":
e.Host = e.PrivateIP e.Host = cont.PrivateIP
case e.PublicIP != "": case cont.PublicIP != "":
e.Host = e.PublicIP e.Host = cont.PublicIP
case !isDocker: case !isDocker:
e.Host = "localhost" e.Host = "localhost"
} }
@ -60,14 +61,14 @@ func (e *RawEntry) FillMissingFields() {
lp, pp, extra := e.splitPorts() lp, pp, extra := e.splitPorts()
if port, ok := common.ServiceNamePortMapTCP[e.ImageName]; ok { if port, ok := common.ServiceNamePortMapTCP[cont.ImageName]; ok {
if pp == "" { if pp == "" {
pp = strconv.Itoa(port) pp = strconv.Itoa(port)
} }
if e.Scheme == "" { if e.Scheme == "" {
e.Scheme = "tcp" e.Scheme = "tcp"
} }
} else if port, ok := common.ImageNamePortMap[e.ImageName]; ok { } else if port, ok := common.ImageNamePortMap[cont.ImageName]; ok {
if pp == "" { if pp == "" {
pp = strconv.Itoa(port) pp = strconv.Itoa(port)
} }
@ -77,9 +78,9 @@ func (e *RawEntry) FillMissingFields() {
} else if pp == "" && e.Scheme == "https" { } else if pp == "" && e.Scheme == "https" {
pp = "443" pp = "443"
} else if pp == "" { } else if pp == "" {
if p := lowestPort(e.PrivatePortMapping); p != "" { if p := lowestPort(cont.PrivatePortMapping); p != "" {
pp = p pp = p
} else if p := lowestPort(e.PublicPortMapping); p != "" { } else if p := lowestPort(cont.PublicPortMapping); p != "" {
pp = p pp = p
} else if !isDocker { } else if !isDocker {
pp = "80" pp = "80"
@ -89,23 +90,23 @@ func (e *RawEntry) FillMissingFields() {
} }
// replace private port with public port if using public IP. // replace private port with public port if using public IP.
if e.Host == e.PublicIP { if e.Host == cont.PublicIP {
if p, ok := e.PrivatePortMapping[pp]; ok { if p, ok := cont.PrivatePortMapping[pp]; ok {
pp = U.PortString(p.PublicPort) pp = U.PortString(p.PublicPort)
} }
} }
// replace public port with private port if using private IP. // replace public port with private port if using private IP.
if e.Host == e.PrivateIP { if e.Host == cont.PrivateIP {
if p, ok := e.PublicPortMapping[pp]; ok { if p, ok := cont.PublicPortMapping[pp]; ok {
pp = U.PortString(p.PrivatePort) pp = U.PortString(p.PrivatePort)
} }
} }
if e.Scheme == "" && isDocker { if e.Scheme == "" && isDocker {
switch { switch {
case e.Host == e.PublicIP && e.PublicPortMapping[pp].Type == "udp": case e.Host == cont.PublicIP && cont.PublicPortMapping[pp].Type == "udp":
e.Scheme = "udp" e.Scheme = "udp"
case e.Host == e.PrivateIP && e.PrivatePortMapping[pp].Type == "udp": case e.Host == cont.PrivateIP && cont.PrivatePortMapping[pp].Type == "udp":
e.Scheme = "udp" e.Scheme = "udp"
} }
} }
@ -127,17 +128,17 @@ func (e *RawEntry) FillMissingFields() {
if e.HealthCheck.Timeout == 0 { if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
} }
if e.IdleTimeout == "" { if cont.IdleTimeout == "" {
e.IdleTimeout = common.IdleTimeoutDefault cont.IdleTimeout = common.IdleTimeoutDefault
} }
if e.WakeTimeout == "" { if cont.WakeTimeout == "" {
e.WakeTimeout = common.WakeTimeoutDefault cont.WakeTimeout = common.WakeTimeoutDefault
} }
if e.StopTimeout == "" { if cont.StopTimeout == "" {
e.StopTimeout = common.StopTimeoutDefault cont.StopTimeout = common.StopTimeoutDefault
} }
if e.StopMethod == "" { if cont.StopMethod == "" {
e.StopMethod = common.StopMethodDefault cont.StopMethod = common.StopMethodDefault
} }
e.Port = joinPorts(lp, pp, extra) e.Port = joinPorts(lp, pp, extra)

View file

@ -99,9 +99,53 @@ func (p BidirectionalPipe) Start() error {
return b.Build().Error() return b.Build().Error()
} }
func Copy(dst *ContextWriter, src *ContextReader) error { // Copyright 2009 The Go Authors. All rights reserved.
_, err := io.Copy(dst, src) // Use of this source code is governed by a BSD-style
return err // This is a copy of io.Copy with context handling
// Author: yusing <yusing@6uo.me>
func Copy(dst *ContextWriter, src *ContextReader) (err error) {
size := 32 * 1024
if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf := make([]byte, size)
for {
select {
case <-src.ctx.Done():
return src.ctx.Err()
case <-dst.ctx.Done():
return dst.ctx.Err()
default:
nr, er := src.Reader.Read(buf)
if nr > 0 {
nw, ew := dst.Writer.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write result")
}
}
if ew != nil {
err = ew
return
}
if nr != nw {
err = io.ErrShortWrite
return
}
}
if er != nil {
if er != io.EOF {
err = er
}
return
}
}
}
} }
func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {

View file

@ -99,12 +99,10 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
continue // Ignore this field if the tag is "-" continue // Ignore this field if the tag is "-"
} }
if strings.Contains(jsonTag, ",omitempty") { if strings.Contains(jsonTag, ",omitempty") {
if field.Type.Kind() == reflect.Ptr && value.Field(i).IsNil() {
continue
}
if value.Field(i).IsZero() { if value.Field(i).IsZero() {
continue continue
} }
jsonTag = strings.Replace(jsonTag, ",omitempty", "", 1)
} }
// If the json tag is not empty, use it as the key // If the json tag is not empty, use it as the key

View file

@ -7,7 +7,7 @@ import (
) )
type HealthCheckConfig struct { type HealthCheckConfig struct {
Disabled bool `json:"disabled" yaml:"disabled"` Disabled bool `json:"disabled,omitempty" yaml:"disabled"`
Path string `json:"path,omitempty" yaml:"path"` Path string `json:"path,omitempty" yaml:"path"`
UseGet bool `json:"use_get,omitempty" yaml:"use_get"` UseGet bool `json:"use_get,omitempty" yaml:"use_get"`
Interval time.Duration `json:"interval" yaml:"interval"` Interval time.Duration `json:"interval" yaml:"interval"`

View file

@ -0,0 +1,30 @@
package health
import (
"encoding/json"
"time"
"github.com/yusing/go-proxy/internal/net/types"
)
type JSONRepresentation struct {
Name string
Config *HealthCheckConfig
Status Status
Started time.Time
Uptime time.Duration
URL types.URL
Extra map[string]any
}
func (jsonRepr *JSONRepresentation) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": jsonRepr.Name,
"config": jsonRepr.Config,
"started": jsonRepr.Started.Unix(),
"status": jsonRepr.Status.String(),
"uptime": jsonRepr.Uptime.Seconds(),
"url": jsonRepr.URL.String(),
"extra": jsonRepr.Extra,
})
}

View file

@ -2,7 +2,6 @@ package health
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"sync" "sync"
"time" "time"
@ -25,8 +24,9 @@ type (
} }
HealthCheckFunc func() (healthy bool, detail string, err error) HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct { monitor struct {
config *HealthCheckConfig service string
url types.URL config *HealthCheckConfig
url types.URL
status U.AtomicValue[Status] status U.AtomicValue[Status]
checkHealth HealthCheckFunc checkHealth HealthCheckFunc
@ -43,8 +43,10 @@ type (
var monMap = F.NewMapOf[string, HealthMonitor]() var monMap = F.NewMapOf[string, HealthMonitor]()
func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
task, cancel := task.SubtaskWithCancel("Health monitor for %s", task.Name()) service := task.Name()
task, cancel := task.SubtaskWithCancel("Health monitor for %s", service)
mon := &monitor{ mon := &monitor{
service: service,
config: config, config: config,
url: url, url: url,
checkHealth: healthCheckFunc, checkHealth: healthCheckFunc,
@ -57,17 +59,12 @@ func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, heal
return mon return mon
} }
func Inspect(name string) (status Status, ok bool) { func Inspect(name string) (HealthMonitor, bool) {
mon, ok := monMap.Load(name) return monMap.Load(name)
if !ok {
return
}
return mon.Status(), true
} }
func (mon *monitor) Start() { func (mon *monitor) Start() {
defer monMap.Store(mon.task.Name(), mon) defer monMap.Store(mon.task.Name(), mon)
defer logger.Debugf("%s health monitor started", mon.String())
go func() { go func() {
defer close(mon.done) defer close(mon.done)
@ -93,12 +90,9 @@ func (mon *monitor) Start() {
} }
} }
}() }()
logger.Debugf("health monitor %q started", mon.String())
} }
func (mon *monitor) Stop() { func (mon *monitor) Stop() {
defer logger.Debugf("%s health monitor stopped", mon.String())
monMap.Delete(mon.task.Name()) monMap.Delete(mon.task.Name())
mon.mu.Lock() mon.mu.Lock()
@ -132,14 +126,14 @@ func (mon *monitor) String() string {
} }
func (mon *monitor) MarshalJSON() ([]byte, error) { func (mon *monitor) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{ return (&JSONRepresentation{
"name": mon.Name(), Name: mon.Name(),
"url": mon.url, Config: mon.config,
"status": mon.status.Load(), Status: mon.status.Load(),
"uptime": mon.Uptime().String(), Started: mon.startTime,
"started": mon.startTime.Unix(), Uptime: mon.Uptime(),
"config": mon.config, URL: mon.url,
}) }).MarshalJSON()
} }
func (mon *monitor) checkUpdateHealth() (hasError bool) { func (mon *monitor) checkUpdateHealth() (hasError bool) {
@ -147,7 +141,7 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
if err != nil { if err != nil {
mon.status.Store(StatusError) mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
logger.Errorf("%s failed to check health: %s", mon.String(), err) logger.Errorf("%s failed to check health: %s", mon.service, err)
} }
mon.Stop() mon.Stop()
return false return false
@ -160,9 +154,9 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
} }
if healthy != (mon.status.Swap(status) == StatusHealthy) { if healthy != (mon.status.Swap(status) == StatusHealthy) {
if healthy { if healthy {
logger.Infof("%s is up", mon.String()) logger.Infof("%s is up", mon.service)
} else { } else {
logger.Warnf("%s is down: %s", mon.String(), detail) logger.Warnf("%s is down: %s", mon.service, detail)
} }
} }

View file

@ -1,7 +1,5 @@
package health package health
import "encoding/json"
type Status int type Status int
const ( const (
@ -36,7 +34,7 @@ func (s Status) String() string {
} }
func (s Status) MarshalJSON() ([]byte, error) { func (s Status) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String()) return []byte(`"` + s.String() + `"`), nil
} }
func (s Status) Good() bool { func (s Status) Good() bool {