Improved healthcheck, idlewatcher support for load-balanced routes, bug fixes

yusing 2024-10-15 15:34:27 +08:00
parent 53fa28ae77
commit f4d532598c
34 changed files with 568 additions and 423 deletions

.gitignore

@@ -1,6 +1,8 @@
compose.yml
*.compose.yml
config
certs
config*/
certs*/
bin/

go.mod

@@ -41,12 +41,12 @@ require (
github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.30.0 // indirect
go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.30.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/mod v0.21.0 // indirect
golang.org/x/oauth2 v0.23.0 // indirect

go.sum

@@ -96,20 +96,20 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0 h1:ZIg3ZT/aQ7AfKqdwp7ECpOK6vHqquXXuyTjIO8ZdmPs=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.55.0/go.mod h1:DQAwmETtZV00skUwgD6+0U89g80NKsJE3DCKeLLPQMI=
go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts=
go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM=
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0 h1:lsInsfvhVIfOI6qHVyysXMNDnjO9Npvl7tlDPJFBVd4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.30.0/go.mod h1:KQsVNh4OjgjTG0G6EiNi1jVpnaeeKsKMRwbLN+f1+8M=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 h1:umZgi92IyxfXd/l4kaDhnKgY8rnN/cZcF1LKc6I8OQ8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0/go.mod h1:4lVs6obhSVRb1EW5FhOuBTyiQhtRtAnnva9vD3yRfq8=
go.opentelemetry.io/otel/metric v1.30.0 h1:4xNulvn9gjzo4hjg+wzIKG7iNFEaBMX00Qd4QIZs7+w=
go.opentelemetry.io/otel/metric v1.30.0/go.mod h1:aXTfST94tswhWEb+5QjlSqG+cZlmyXy/u8jFpor3WqQ=
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
go.opentelemetry.io/otel/sdk v1.30.0 h1:cHdik6irO49R5IysVhdn8oaiR9m8XluDaJAs4DfOrYE=
go.opentelemetry.io/otel/sdk v1.30.0/go.mod h1:p14X4Ok8S+sygzblytT1nqG98QG2KYKv++HE0LY/mhg=
go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc=
go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o=
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

@@ -15,10 +15,15 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
return
}
status, ok := health.Inspect(target)
result, ok := health.Inspect(target)
if !ok {
HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound)
return
}
WriteBody(w, []byte(status.String()))
json, err := result.MarshalJSON()
if err != nil {
HandleErr(w, r, err)
return
}
RespondJSON(w, r, json)
}
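The handler above now returns the monitor's full JSON form instead of a bare status string, which the health.Inspect change further down enables (Inspect now returns the HealthMonitor itself). A minimal consumption sketch; the alias "my-app" is illustrative:

if mon, ok := health.Inspect("my-app"); ok {
	data, err := mon.MarshalJSON()
	if err == nil {
		fmt.Println(string(data)) // e.g. {"name":"my-app","status":"healthy",...}
	}
}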

@@ -8,6 +8,7 @@ import (
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -45,16 +46,7 @@
}
func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) {
routes := cfg.RoutesByAlias()
typeFilter := r.FormValue("type")
if typeFilter != "" {
for k, v := range routes {
if v["type"] != typeFilter {
delete(routes, k)
}
}
}
routes := cfg.RoutesByAlias(route.RouteType(r.FormValue("type")))
U.RespondJSON(w, r, routes)
}

@@ -212,9 +212,9 @@ func GlobalContextWait(timeout time.Duration) {
case <-done:
return
case <-after:
logrus.Println("Timeout waiting for these tasks to finish:")
logrus.Warnln("Timeout waiting for these tasks to finish:")
tasksMap.Range(func(t *task, _ struct{}) bool {
logrus.Println(t.tree())
logrus.Warnln(t.tree())
return true
})
return

@@ -87,9 +87,8 @@ func (cfg *Config) StartProxyProviders() {
func (cfg *Config) WatchChanges() {
task := common.NewTask("Config watcher")
defer task.Finished()
go func() {
defer task.Finished()
for {
select {
case <-task.Context().Done():

@@ -42,25 +42,18 @@ func (cfg *Config) HomepageConfig() homepage.Config {
}
hpCfg := homepage.NewHomePageConfig()
cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) {
if !r.Started() {
return
}
entry := r.Entry
if entry.Homepage == nil {
entry.Homepage = &homepage.Item{
Show: r.Entry.IsExplicit || !p.IsExplicitOnly(),
}
}
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
entry := r.Raw
item := entry.Homepage
if item == nil {
item = new(homepage.Item)
}
if !item.Show && !item.IsEmpty() {
if !item.Show && item.IsEmpty() {
item.Show = true
}
if !item.Show || r.Type != R.RouteTypeReverseProxy {
if !item.Show {
return
}
@@ -73,12 +66,17 @@ func (cfg *Config) HomepageConfig() homepage.Config {
)
}
if p.GetType() == PR.ProviderTypeDocker {
if r.IsDocker() {
if item.Category == "" {
item.Category = "Docker"
}
item.SourceType = string(PR.ProviderTypeDocker)
} else if p.GetType() == PR.ProviderTypeFile {
} else if r.UseLoadBalance() {
if item.Category == "" {
item.Category = "Load-balanced"
}
item.SourceType = "loadbalancer"
} else {
if item.Category == "" {
item.Category = "Others"
}
@@ -97,28 +95,20 @@ func (cfg *Config) HomepageConfig() homepage.Config {
return hpCfg
}
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]U.SerializedObject {
routes := make(map[string]U.SerializedObject)
if len(typeFilter) == 0 {
func (cfg *Config) RoutesByAlias(typeFilter ...R.RouteType) map[string]any {
routes := make(map[string]any)
if len(typeFilter) == 0 || typeFilter[0] == "" {
typeFilter = []R.RouteType{R.RouteTypeReverseProxy, R.RouteTypeStream}
}
for _, t := range typeFilter {
switch t {
case R.RouteTypeReverseProxy:
R.GetReverseProxies().RangeAll(func(alias string, r *R.HTTPRoute) {
obj, err := U.Serialize(r)
if err != nil {
panic(err) // should not happen
}
routes[alias] = obj
routes[alias] = r
})
case R.RouteTypeStream:
R.GetStreamProxies().RangeAll(func(alias string, r *R.StreamRoute) {
obj, err := U.Serialize(r)
if err != nil {
panic(err) // should not happen
}
routes[alias] = obj
routes[alias] = r
})
}
}
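Note the new filter semantics: an empty or missing type falls back to both route types inside RoutesByAlias itself, so the API handler can pass the query value straight through. A hedged usage sketch:

all := cfg.RoutesByAlias()                        // both reverse-proxy and stream routes
same := cfg.RoutesByAlias(R.RouteType(""))        // empty filter: identical result
rps := cfg.RoutesByAlias(R.RouteTypeReverseProxy) // reverse proxies only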

@@ -43,7 +43,9 @@ func init() {
select {
case <-task.Context().Done():
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
}
})
clientMap.Clear()
return

@@ -41,6 +41,8 @@ type (
}
)
var DummyContainer = new(Container)
func FromDocker(c *types.Container, dockerHost string) (res *Container) {
isExplicit := c.Labels[LabelAliases] != ""
helper := containerHelper{c}

@@ -2,7 +2,6 @@ package idlewatcher
import (
"context"
"encoding/json"
"net/http"
"strconv"
"time"
@@ -73,15 +72,15 @@
}
func (w *Waker) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": w.Name(),
"url": w.URL,
"status": w.Status(),
"config": health.HealthCheckConfig{
return (&health.JSONRepresentation{
Name: w.Name(),
Status: w.Status(),
Config: &health.HealthCheckConfig{
Interval: w.IdleTimeout,
Timeout: w.WakeTimeout,
},
})
URL: w.URL,
}).MarshalJSON()
}
/* End of HealthMonitor interface */
@@ -89,6 +88,10 @@ func (w *Waker) MarshalJSON() ([]byte, error) {
func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is ready
if w.ready.Load() {
return true
@@ -115,6 +118,16 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool)
return
}
select {
case <-w.task.Context().Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
case <-ctx.Done():
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return
default:
}
// wake the container and reset idle timer
// also wait for another wake request
w.wakeCh <- struct{}{}
@@ -169,3 +182,8 @@ func (w *Waker) wake(rw http.ResponseWriter, r *http.Request) (shouldNext bool)
time.Sleep(100 * time.Millisecond)
}
}
// static HealthMonitor interface check
func (w *Waker) _() health.HealthMonitor {
return w
}
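The no-op method above is a compile-time assertion that *Waker satisfies health.HealthMonitor. For comparison, the more conventional Go spelling of the same check is a package-level variable, which adds no method to the type:

var _ health.HealthMonitor = (*Waker)(nil)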

@@ -8,10 +8,12 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher"
)
@@ -29,9 +31,10 @@ type (
wakeDone chan E.NestedError
ticker *time.Ticker
ctx context.Context
task common.Task
cancel context.CancelFunc
refCount *sync.WaitGroup
refCount *U.RefCount
l logrus.FieldLogger
}
@@ -42,17 +45,11 @@ type (
)
var (
mainLoopCtx context.Context
mainLoopCancel context.CancelFunc
mainLoopWg sync.WaitGroup
watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex
portHistoryMap = F.NewMapOf[PT.Alias, string]()
newWatcherCh = make(chan *Watcher)
logger = logrus.WithField("module", "idle_watcher")
)
@@ -73,7 +70,7 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
}
if w, ok := watcherMap.Load(key); ok {
w.refCount.Add(1)
w.refCount.Add()
w.ReverseProxyEntry = entry
return w, nil
}
@@ -86,83 +83,51 @@ func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) {
w := &Watcher{
ReverseProxyEntry: entry,
client: client,
refCount: &sync.WaitGroup{},
refCount: U.NewRefCounter(),
wakeCh: make(chan struct{}, 1),
wakeDone: make(chan E.NestedError),
ticker: time.NewTicker(entry.IdleTimeout),
l: logger.WithField("container", entry.ContainerName),
}
w.refCount.Add(1)
w.task, w.cancel = common.NewTaskWithCancel("Idlewatcher for %s", w.Alias)
w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w)
go func() {
newWatcherCh <- w
}()
go w.watchUntilCancel()
return w, nil
}
func (w *Watcher) Unregister() {
w.refCount.Add(-1)
}
func Start() {
logger.Debug("started")
defer logger.Debug("stopped")
mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())
for {
select {
case <-mainLoopCtx.Done():
return
case w := <-newWatcherCh:
w.l.Debug("registered")
mainLoopWg.Add(1)
go func() {
w.watchUntilCancel()
w.refCount.Wait() // wait for 0 ref count
watcherMap.Delete(w.ContainerID)
w.l.Debug("unregistered")
mainLoopWg.Done()
}()
}
}
}
func Stop() {
mainLoopCancel()
mainLoopWg.Wait()
w.refCount.Sub()
}
func (w *Watcher) containerStop() error {
return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{
return w.client.ContainerStop(w.task.Context(), w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal),
Timeout: &w.StopTimeout,
})
}
func (w *Watcher) containerPause() error {
return w.client.ContainerPause(w.ctx, w.ContainerID)
return w.client.ContainerPause(w.task.Context(), w.ContainerID)
}
func (w *Watcher) containerKill() error {
return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal))
return w.client.ContainerKill(w.task.Context(), w.ContainerID, string(w.StopSignal))
}
func (w *Watcher) containerUnpause() error {
return w.client.ContainerUnpause(w.ctx, w.ContainerID)
return w.client.ContainerUnpause(w.task.Context(), w.ContainerID)
}
func (w *Watcher) containerStart() error {
return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{})
return w.client.ContainerStart(w.task.Context(), w.ContainerID, container.StartOptions{})
}
func (w *Watcher) containerStatus() (string, E.NestedError) {
json, err := w.client.ContainerInspect(w.ctx, w.ContainerID)
json, err := w.client.ContainerInspect(w.task.Context(), w.ContainerID)
if err != nil {
return "", E.FailWith("inspect container", err)
}
@@ -221,12 +186,8 @@ func (w *Watcher) resetIdleTimer() {
}
func (w *Watcher) watchUntilCancel() {
defer close(w.wakeCh)
w.ctx, w.cancel = context.WithCancel(mainLoopCtx)
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.task.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID),
@@ -238,13 +199,23 @@ func (w *Watcher) watchUntilCancel() {
W.DockerFilterUnpause,
),
})
defer w.ticker.Stop()
defer w.client.Close()
defer func() {
w.ticker.Stop()
w.client.Close()
close(w.wakeDone)
close(w.wakeCh)
watcherMap.Delete(w.ContainerID)
w.task.Finished()
}()
for {
select {
case <-w.ctx.Done():
w.l.Debug("stopped")
case <-w.task.Context().Done():
w.l.Debug("stopped by context done")
return
case <-w.refCount.Zero():
w.l.Debug("stopped by zero ref count")
return
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {

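The watcher now shuts down when either its task context is cancelled or the last route referencing it unregisters. The RefCount utility it selects on is not part of this diff; below is a minimal sketch of the API the code above implies (Add, Sub, and a Zero channel), with the caveat that the real implementation in internal/utils may differ:

package utils

import "sync"

// RefCount sketch: Sub closes the Zero() channel once the count
// reaches zero, so callers can select on it like a context.
type RefCount struct {
	mu    sync.Mutex
	count int
	zero  chan struct{}
}

func NewRefCounter() *RefCount {
	// Creation counts as the first reference.
	return &RefCount{count: 1, zero: make(chan struct{})}
}

func (rc *RefCount) Add() {
	rc.mu.Lock()
	defer rc.mu.Unlock()
	rc.count++
}

func (rc *RefCount) Sub() {
	rc.mu.Lock()
	defer rc.mu.Unlock()
	rc.count--
	if rc.count == 0 {
		close(rc.zero)
	}
}

func (rc *RefCount) Zero() <-chan struct{} { return rc.zero }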
@@ -7,6 +7,17 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) {
client, err := ConnectClient(dockerHost)
defer client.Close()
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
return client.Inspect(containerID)
}
func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

@@ -46,7 +46,18 @@ func TestBuilderNested(t *testing.T) {
- invalid Inner: "2"
- Action 2 failed:
- invalid Inner: "3"`)
if got != expected1 && got != expected2 {
t.Errorf("expected \n%s, got \n%s", expected1, got)
}
ExpectEqualAny(t, got, []string{expected1, expected2})
}
func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Addf("abcd")
var err NestedError
eb.To(&err)
got := err.String()
expected := (`error occurred:
- abcd`)
ExpectEqual(t, got, expected)
}

@@ -6,8 +6,8 @@ import (
"time"
"github.com/go-acme/lego/v4/log"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// TODO: stats of each server.
@@ -41,13 +41,14 @@ type (
const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{Config: cfg, pool: servers{}}
mode := cfg.Mode
if !cfg.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %s", cfg.Link, mode, cfg.Mode)
}
switch mode {
case RoundRobin:
lb := &LoadBalancer{Config: new(Config), pool: make(servers, 0)}
lb.UpdateConfigIfNeeded(cfg)
return lb
}
func (lb *LoadBalancer) updateImpl() {
switch lb.Mode {
case Unset, RoundRobin:
lb.impl = lb.newRoundRobin()
case LeastConn:
lb.impl = lb.newLeastConn()
@@ -56,7 +57,34 @@ func New(cfg *Config) *LoadBalancer {
default: // should happen in test only
lb.impl = lb.newRoundRobin()
}
return lb
for _, srv := range lb.pool {
lb.impl.OnAddServer(srv)
}
}
func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if cfg != nil {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.Link = cfg.Link
if lb.Mode == Unset && cfg.Mode != Unset {
lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode)
}
lb.updateImpl()
}
if len(lb.Options) == 0 && len(cfg.Options) > 0 {
lb.Options = cfg.Options
}
}
if lb.impl == nil {
lb.updateImpl()
}
}
func (lb *LoadBalancer) AddServer(srv *Server) {
@@ -66,6 +94,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.pool = append(lb.pool, srv)
lb.sumWeight += srv.Weight
lb.Rebalance()
lb.impl.OnAddServer(srv)
logger.Debugf("[add] loadbalancer %s: %d servers available", lb.Link, len(lb.pool))
}
@@ -74,6 +103,8 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
lb.sumWeight -= srv.Weight
lb.Rebalance()
lb.impl.OnRemoveServer(srv)
for i, s := range lb.pool {
@@ -87,7 +118,6 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
return
}
lb.Rebalance()
logger.Debugf("[remove] loadbalancer %s: %d servers left", lb.Link, len(lb.pool))
}
@@ -152,20 +182,6 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
func (lb *LoadBalancer) Start() {
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
for _, s := range lb.pool {
msg.Addf("%s: %d", s.Name, s.Weight)
}
lb.Rebalance()
inner := E.NewBuilder("after rebalancing")
for _, s := range lb.pool {
inner.Addf("%s: %d", s.Name, s.Weight)
}
msg.Addf("%s", inner)
logger.Warn(msg)
}
if lb.sumWeight != 0 {
log.Warnf("weighted mode not supported yet")
}
@@ -186,6 +202,45 @@ func (lb *LoadBalancer) Uptime() time.Duration {
return time.Since(lb.startTime)
}
// MarshalJSON implements health.HealthMonitor.
func (lb *LoadBalancer) MarshalJSON() ([]byte, error) {
extra := make(map[string]any)
for _, v := range lb.pool {
extra[v.Name] = v.healthMon
}
return (&health.JSONRepresentation{
Name: lb.Name(),
Status: lb.Status(),
Started: lb.startTime,
Uptime: lb.Uptime(),
Extra: map[string]any{
"config": lb.Config,
"pool": extra,
},
}).MarshalJSON()
}
// Name implements health.HealthMonitor.
func (lb *LoadBalancer) Name() string {
return lb.Link
}
// Status implements health.HealthMonitor.
func (lb *LoadBalancer) Status() health.Status {
if len(lb.pool) == 0 {
return health.StatusUnknown
}
if len(lb.availServers()) == 0 {
return health.StatusUnhealthy
}
return health.StatusHealthy
}
// String implements health.HealthMonitor.
func (lb *LoadBalancer) String() string {
return lb.Name()
}
func (lb *LoadBalancer) availServers() servers {
lb.poolMu.Lock()
defer lb.poolMu.Unlock()
@@ -199,3 +254,8 @@ func (lb *LoadBalancer) availServers() servers {
}
return avail
}
// static HealthMonitor interface check
func (lb *LoadBalancer) _() health.HealthMonitor {
return lb
}
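Taken together, New and UpdateConfigIfNeeded let whichever linked route starts first construct the shared balancer, while routes that start later only fill in fields that are still unset. A hedged sketch of the call pattern (mirroring addToLoadBalancer in the HTTP route changes below; routeA, routeB and srvB are illustrative):

lb := loadbalancer.New(routeA.LoadBalance)  // first route constructs and configures
lb.UpdateConfigIfNeeded(routeB.LoadBalance) // later route: mode/options applied only if unset
lb.AddServer(srvB)                          // the impl is notified via OnAddServer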

@@ -7,6 +7,7 @@ import (
type Mode string
const (
Unset Mode = ""
RoundRobin Mode = "roundrobin"
LeastConn Mode = "leastconn"
IPHash Mode = "iphash"
@@ -14,7 +15,9 @@ const (
func (mode *Mode) ValidateUpdate() bool {
switch U.ToLowerNoSnake(string(*mode)) {
case "", string(RoundRobin):
case "":
return true
case string(RoundRobin):
*mode = RoundRobin
return true
case string(LeastConn):

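ValidateUpdate now validates and canonicalizes the mode in place (note the pointer receiver), and an empty mode stays Unset so a linked route's config can still decide it later. A sketch of the intended behavior, assuming ToLowerNoSnake lower-cases the value and strips separators:

m := Mode("LeastConn")
ok := m.ValidateUpdate() // true; m is normalized to LeastConn
u := Mode("")
_ = u.ValidateUpdate()   // true; u stays Unset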
@@ -16,32 +16,36 @@ import (
type (
ReverseProxyEntry struct { // real model after validation
Alias T.Alias `json:"alias"`
Scheme T.Scheme `json:"scheme"`
URL net.URL `json:"url"`
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.Scheme `json:"scheme,omitempty"`
URL net.URL `json:"url,omitempty"`
NoTLSVerify bool `json:"no_tls_verify,omitempty"`
PathPatterns T.PathPatterns `json:"path_patterns"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck"`
PathPatterns T.PathPatterns `json:"path_patterns,omitempty"`
HealthCheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
LoadBalance *loadbalancer.Config `json:"load_balance,omitempty"`
Middlewares D.NestedLabelMap `json:"middlewares,omitempty"`
/* Docker only */
IdleTimeout time.Duration `json:"idle_timeout"`
WakeTimeout time.Duration `json:"wake_timeout"`
StopMethod T.StopMethod `json:"stop_method"`
StopTimeout int `json:"stop_timeout"`
IdleTimeout time.Duration `json:"idle_timeout,omitempty"`
WakeTimeout time.Duration `json:"wake_timeout,omitempty"`
StopMethod T.StopMethod `json:"stop_method,omitempty"`
StopTimeout int `json:"stop_timeout,omitempty"`
StopSignal T.Signal `json:"stop_signal,omitempty"`
DockerHost string `json:"docker_host"`
ContainerName string `json:"container_name"`
ContainerID string `json:"container_id"`
ContainerRunning bool `json:"container_running"`
DockerHost string `json:"docker_host,omitempty"`
ContainerName string `json:"container_name,omitempty"`
ContainerID string `json:"container_id,omitempty"`
ContainerRunning bool `json:"container_running,omitempty"`
}
StreamEntry struct {
Alias T.Alias `json:"alias"`
Scheme T.StreamScheme `json:"scheme"`
Host T.Host `json:"host"`
Port T.StreamPort `json:"port"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck"`
Raw *types.RawEntry `json:"raw"`
Alias T.Alias `json:"alias,omitempty"`
Scheme T.StreamScheme `json:"scheme,omitempty"`
Host T.Host `json:"host,omitempty"`
Port T.StreamPort `json:"port,omitempty"`
Healthcheck *health.HealthCheckConfig `json:"healthcheck,omitempty"`
}
)
@@ -88,6 +92,10 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) {
func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry {
var stopTimeOut time.Duration
cont := m.Container
if cont == nil {
cont = D.DummyContainer
}
host, err := T.ValidateHost(m.Host)
b.Add(err)
@@ -101,21 +109,21 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout)
idleTimeout, err := T.ValidateDurationPostitive(cont.IdleTimeout)
b.Add(err)
wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout)
wakeTimeout, err := T.ValidateDurationPostitive(cont.WakeTimeout)
b.Add(err)
stopMethod, err := T.ValidateStopMethod(m.StopMethod)
stopMethod, err := T.ValidateStopMethod(cont.StopMethod)
b.Add(err)
if stopMethod == T.StopMethodStop {
stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout)
stopTimeOut, err = T.ValidateDurationPostitive(cont.StopTimeout)
b.Add(err)
}
stopSignal, err := T.ValidateSignal(m.StopSignal)
stopSignal, err := T.ValidateSignal(cont.StopSignal)
b.Add(err)
if err != nil {
@@ -123,6 +131,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
}
return &ReverseProxyEntry{
Raw: m,
Alias: T.NewAlias(m.Alias),
Scheme: s,
URL: net.NewURL(url),
@@ -136,10 +145,10 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn
StopMethod: stopMethod,
StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument
StopSignal: stopSignal,
DockerHost: m.DockerHost,
ContainerName: m.ContainerName,
ContainerID: m.ContainerID,
ContainerRunning: m.Running,
DockerHost: cont.DockerHost,
ContainerName: cont.ContainerName,
ContainerID: cont.ContainerID,
ContainerRunning: cont.Running,
}
}
@@ -158,6 +167,7 @@ func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry {
}
return &StreamEntry{
Raw: m,
Alias: T.NewAlias(m.Alias),
Scheme: *scheme,
Host: host,

@@ -32,7 +32,7 @@ func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
}
func (s StreamScheme) String() string {
return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme)
return fmt.Sprintf("%s -> %s", s.ListeningScheme, s.ProxyScheme)
}
// IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal.

@@ -72,7 +72,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
}
entries.RangeAll(func(_ string, e *types.RawEntry) {
e.DockerHost = p.dockerHost
e.Container.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
@@ -88,7 +88,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
strings.HasSuffix(container.ContainerName, "-old")
}
func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) {
func (p *DockerProvider) OnEvent(event W.Event, oldRoutes R.Routes) (res EventResult) {
switch event.Action {
case events.ActionContainerStart, events.ActionContainerStop:
break
@@ -98,74 +98,65 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul
b := E.NewBuilder("event %s error", event)
defer b.To(&res.err)
routes.RangeAll(func(k string, v *R.Route) {
if v.Entry.ContainerID == event.ActorID ||
v.Entry.ContainerName == event.ActorName {
b.Add(v.Stop())
routes.Delete(k)
res.nRemoved++
matches := R.NewRoutes()
oldRoutes.RangeAllParallel(func(k string, v *R.Route) {
if v.Entry.Container.ContainerID == event.ActorID ||
v.Entry.Container.ContainerName == event.ActorName {
matches.Store(k, v)
}
})
if res.nRemoved == 0 { // id & container name changed
// load all routes (rescan)
routesNew, err := p.LoadRoutesImpl()
routesOld := routes
if routesNew.Size() == 0 {
b.Add(E.FailWith("rescan routes", err))
return
}
routesNew.Range(func(k string, v *R.Route) bool {
if !routesOld.Has(k) {
routesOld.Store(k, v)
b.Add(v.Start())
res.nAdded++
return false
}
return true
})
routesOld.Range(func(k string, v *R.Route) bool {
if !routesNew.Has(k) {
b.Add(v.Stop())
routesOld.Delete(k)
res.nRemoved++
return false
}
return true
})
return
}
var newRoutes R.Routes
var err E.NestedError
client, err := D.ConnectClient(p.dockerHost)
if err != nil {
b.Add(E.FailWith("connect to docker", err))
return
}
defer client.Close()
cont, err := client.Inspect(event.ActorID)
if matches.Size() == 0 { // id & container name changed
matches = oldRoutes
newRoutes, err = p.LoadRoutesImpl()
b.Add(err)
} else {
cont, err := D.Inspect(p.dockerHost, event.ActorID)
if err != nil {
b.Add(E.FailWith("inspect container", err))
return
}
if p.shouldIgnore(cont) {
// stop all old routes
matches.RangeAllParallel(func(_ string, v *R.Route) {
b.Add(v.Stop())
})
return
}
entries, err := p.entriesFromContainerLabels(cont)
b.Add(err)
entries.RangeAll(func(alias string, entry *types.RawEntry) {
if routes.Has(alias) {
b.Add(E.Duplicated("alias", alias))
} else {
if route, err := R.NewRoute(entry); err != nil {
newRoutes, err = R.FromEntries(entries)
b.Add(err)
} else {
routes.Store(alias, route)
b.Add(route.Start())
res.nAdded++
}
matches.RangeAll(func(k string, v *R.Route) {
if !newRoutes.Has(k) && !oldRoutes.Has(k) {
b.Add(v.Stop())
matches.Delete(k)
res.nRemoved++
}
})
newRoutes.RangeAll(func(alias string, newRoute *R.Route) {
oldRoute, exists := oldRoutes.Load(alias)
if exists {
if err := oldRoute.Stop(); err != nil {
b.Add(err)
}
}
oldRoutes.Store(alias, newRoute)
if err := newRoute.Start(); err != nil {
b.Add(err)
}
if exists {
res.nReloaded++
} else {
res.nAdded++
}
})

@@ -88,20 +88,20 @@ func TestApplyLabelWildcard(t *testing.T) {
ExpectDeepEqual(t, a.Middlewares, middlewaresExpect)
ExpectEqual(t, len(b.Middlewares), 0)
ExpectEqual(t, a.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, b.Container.IdleTimeout, common.IdleTimeoutDefault)
ExpectEqual(t, a.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, b.Container.StopTimeout, common.StopTimeoutDefault)
ExpectEqual(t, a.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.Container.StopMethod, common.StopMethodDefault)
ExpectEqual(t, b.Container.StopMethod, common.StopMethodDefault)
ExpectEqual(t, a.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.Container.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, b.Container.WakeTimeout, common.WakeTimeoutDefault)
ExpectEqual(t, a.StopSignal, "SIGTERM")
ExpectEqual(t, b.StopSignal, "SIGTERM")
ExpectEqual(t, a.Container.StopSignal, "SIGTERM")
ExpectEqual(t, b.Container.StopSignal, "SIGTERM")
}
func TestApplyLabelWithAlias(t *testing.T) {
@@ -186,16 +186,16 @@ func TestPublicIPLocalhost(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.PublicIP, "127.0.0.1")
ExpectEqual(t, raw.Host, raw.PublicIP)
ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1")
ExpectEqual(t, raw.Host, raw.Container.PublicIP)
}
func TestPublicIPRemote(t *testing.T) {
c := D.FromDocker(&types.Container{Names: dummyNames}, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.PublicIP)
ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.Container.PublicIP)
}
func TestPrivateIPLocalhost(t *testing.T) {
@@ -211,8 +211,8 @@ func TestPrivateIPLocalhost(t *testing.T) {
}, client.DefaultDockerHost)
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.PrivateIP, "172.17.0.123")
ExpectEqual(t, raw.Host, raw.PrivateIP)
ExpectEqual(t, raw.Container.PrivateIP, "172.17.0.123")
ExpectEqual(t, raw.Host, raw.Container.PrivateIP)
}
func TestPrivateIPRemote(t *testing.T) {
@@ -228,9 +228,9 @@ func TestPrivateIPRemote(t *testing.T) {
}, "tcp://1.2.3.4:2375")
raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a")
ExpectTrue(t, ok)
ExpectEqual(t, raw.PrivateIP, "")
ExpectEqual(t, raw.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.PublicIP)
ExpectEqual(t, raw.Container.PrivateIP, "")
ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4")
ExpectEqual(t, raw.Host, raw.Container.PublicIP)
}
func TestStreamDefaultValues(t *testing.T) {

@@ -5,6 +5,7 @@ import (
"path"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
R "github.com/yusing/go-proxy/internal/route"
W "github.com/yusing/go-proxy/internal/watcher"
@@ -19,7 +20,7 @@ type (
routes R.Routes
watcher W.Watcher
watcherCtx context.Context
watcherTask common.Task
watcherCancel context.CancelFunc
l *logrus.Entry
@@ -38,8 +39,9 @@ type (
Type ProviderType `json:"type"`
}
EventResult struct {
nRemoved int
nAdded int
nRemoved int
nReloaded int
err E.NestedError
}
)
@@ -129,6 +131,7 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) {
p.routes.RangeAllParallel(func(alias string, r *R.Route) {
errors.Add(r.Stop().Subject(r))
})
p.routes.Clear()
return
}
@@ -175,27 +178,21 @@ func (p *Provider) Statistics() ProviderStats {
}
func (p *Provider) watchEvents() {
p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background())
events, errs := p.watcher.Events(p.watcherCtx)
p.watcherTask, p.watcherCancel = common.NewTaskWithCancel("Watcher for provider %s", p.name)
defer p.watcherTask.Finished()
events, errs := p.watcher.Events(p.watcherTask.Context())
l := p.l.WithField("module", "watcher")
for {
select {
case <-p.watcherCtx.Done():
case <-p.watcherTask.Context().Done():
return
case event := <-events:
res := p.OnEvent(event, p.routes)
if res.nAdded+res.nRemoved+res.nReloaded > 0 {
l.Infof("%s event %q", event.Type, event)
if res.nAdded > 0 || res.nRemoved > 0 {
n := res.nAdded - res.nRemoved
switch {
case n == 0:
l.Infof("%d route(s) reloaded", res.nAdded)
case n > 0:
l.Infof("%d route(s) added", n)
default:
l.Infof("%d route(s) removed", -n)
}
l.Infof("| %d NEW | %d REMOVED | %d RELOADED |", res.nAdded, res.nRemoved, res.nReloaded)
}
if res.err != nil {
l.Error(res.err)

@@ -18,6 +18,7 @@ import (
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/watcher/health"
)
@@ -26,9 +27,9 @@ type (
HTTPRoute struct {
*P.ReverseProxyEntry `json:"entry"`
LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer,omitempty"`
HealthMon health.HealthMonitor `json:"health"`
HealthMon health.HealthMonitor `json:"health,omitempty"`
loadBalancer *loadbalancer.LoadBalancer
server *loadbalancer.Server
handler http.Handler
rp *gphttp.ReverseProxy
@@ -102,10 +103,6 @@ func (r *HTTPRoute) URL() url.URL {
}
func (r *HTTPRoute) Start() E.NestedError {
if r.handler != nil {
return nil
}
if r.ShouldNotServe() {
return nil
}
@@ -113,6 +110,10 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.handler != nil {
return nil
}
if r.HealthCheck.Disabled && (r.UseIdleWatcher() || r.UseLoadBalance()) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias)
r.HealthCheck.Disabled = true
@@ -129,7 +130,10 @@ func (r *HTTPRoute) Start() E.NestedError {
r.HealthMon = waker
case !r.HealthCheck.Disabled:
r.HealthMon = health.NewHTTPHealthMonitor(common.GlobalTask("Reverse proxy "+r.String()), r.URL(), r.HealthCheck)
fallthrough
}
if r.handler == nil {
switch {
case len(r.PathPatterns) == 1 && r.PathPatterns[0] == "/":
r.handler = ReverseProxyHandler{r.rp}
default:
@@ -139,6 +143,11 @@ func (r *HTTPRoute) Start() E.NestedError {
}
r.handler = mux
}
}
if r.HealthMon != nil {
r.HealthMon.Start()
}
if r.UseLoadBalance() {
r.addToLoadBalancer()
@@ -146,9 +155,6 @@ func (r *HTTPRoute) Start() E.NestedError {
httpRoutes.Store(string(r.Alias), r)
}
if r.HealthMon != nil {
r.HealthMon.Start()
}
return nil
}
@@ -160,7 +166,7 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) {
httpRoutesMu.Lock()
defer httpRoutesMu.Unlock()
if r.LoadBalancer != nil {
if r.loadBalancer != nil {
r.removeFromLoadBalancer()
} else {
httpRoutes.Delete(string(r.Alias))
@@ -184,29 +190,40 @@ func (r *HTTPRoute) addToLoadBalancer() {
var lb *loadbalancer.LoadBalancer
linked, ok := httpRoutes.Load(r.LoadBalance.Link)
if ok {
lb = linked.LoadBalancer
lb = linked.loadBalancer
lb.UpdateConfigIfNeeded(r.LoadBalance)
if linked.Raw.Homepage == nil && r.Raw.Homepage != nil {
linked.Raw.Homepage = r.Raw.Homepage
}
} else {
lb = loadbalancer.New(r.LoadBalance)
lb.Start()
linked = &HTTPRoute{
LoadBalancer: lb,
ReverseProxyEntry: &P.ReverseProxyEntry{
Raw: &types.RawEntry{
Homepage: r.Raw.Homepage,
},
Alias: PT.Alias(lb.Link),
},
HealthMon: lb,
loadBalancer: lb,
handler: lb,
}
httpRoutes.Store(r.LoadBalance.Link, linked)
}
r.LoadBalancer = lb
r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server)
}
func (r *HTTPRoute) removeFromLoadBalancer() {
r.LoadBalancer.RemoveServer(r.server)
if r.LoadBalancer.IsEmpty() {
r.loadBalancer.RemoveServer(r.server)
if r.loadBalancer.IsEmpty() {
httpRoutes.Delete(r.LoadBalance.Link)
logrus.Debugf("loadbalancer %q removed from route table", r.LoadBalance.Link)
}
r.server = nil
r.LoadBalancer = nil
r.loadBalancer = nil
}
func ProxyHandler(w http.ResponseWriter, r *http.Request) {

@@ -1,6 +1,7 @@
package route
import (
"github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
url "github.com/yusing/go-proxy/internal/net/types"
P "github.com/yusing/go-proxy/internal/proxy"
@@ -36,6 +37,13 @@ const (
// function alias.
var NewRoutes = F.NewMap[Routes]
func (rt *Route) Container() *docker.Container {
if rt.Entry.Container == nil {
return docker.DummyContainer
}
return rt.Entry.Container
}
func NewRoute(en *types.RawEntry) (*Route, E.NestedError) {
entry, err := P.ValidateEntry(en)
if err != nil {

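The Container accessor, together with DummyContainer from the docker package changes above, spares callers a nil check on non-Docker routes. A hedged usage sketch:

cont := rt.Container() // never nil; zero values for non-Docker routes
if cont.ContainerID != "" {
	// Docker-backed route
}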
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"sync"
"github.com/sirupsen/logrus"
@@ -24,11 +25,10 @@ type StreamRoute struct {
url url.URL
wg sync.WaitGroup
task common.Task
cancel context.CancelFunc
done chan struct{}
connCh chan any
l logrus.FieldLogger
mu sync.Mutex
@@ -61,7 +61,6 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) {
base := &StreamRoute{
StreamEntry: entry,
url: url,
connCh: make(chan any, 100),
}
if entry.Scheme.ListeningScheme.IsTCP() {
base.StreamImpl = NewTCPRoute(base)
@@ -73,7 +72,7 @@ }
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias)
return fmt.Sprintf("stream %s", r.Alias)
}
func (r *StreamRoute) URL() url.URL {
@@ -88,14 +87,12 @@ func (r *StreamRoute) Start() E.NestedError {
return nil
}
r.task, r.cancel = common.NewTaskWithCancel(r.String())
r.wg.Wait()
if err := r.Setup(); err != nil {
return E.FailWith("setup", err)
}
r.done = make(chan struct{})
r.l.Infof("listening on port %d", r.Port.ListeningPort)
r.wg.Add(2)
go r.acceptConnections()
go r.handleConnections()
if !r.Healthcheck.Disabled {
r.HealthMon = health.NewRawHealthMonitor(r.task, r.URL(), r.Healthcheck)
r.HealthMon.Start()
@@ -122,11 +119,7 @@ func (r *StreamRoute) Stop() E.NestedError {
r.cancel()
r.CloseListeners()
r.wg.Wait()
r.task.Finished()
r.task, r.cancel = nil, nil
<-r.done
return nil
}
@@ -135,41 +128,45 @@ func (r *StreamRoute) Started() bool {
}
func (r *StreamRoute) acceptConnections() {
defer r.wg.Done()
var connWg sync.WaitGroup
task := r.task.Subtask("%s accept connections", r.String())
defer func() {
connWg.Wait()
task.Finished()
r.task.Finished()
r.task, r.cancel = nil, nil
close(r.done)
r.done = nil
}()
for {
select {
case <-r.task.Context().Done():
case <-task.Context().Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-r.task.Context().Done():
case <-task.Context().Done():
return
default:
var nErr *net.OpError
ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) {
r.l.Error(err)
}
continue
}
}
r.connCh <- conn
}
}
}
func (r *StreamRoute) handleConnections() {
defer r.wg.Done()
for {
select {
case <-r.task.Context().Done():
return
case conn := <-r.connCh:
connWg.Add(1)
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
}
connWg.Done()
}()
}
}

@@ -4,31 +4,25 @@ import (
"context"
"fmt"
"net"
"sync"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
const tcpDialTimeout = 5 * time.Second
type (
Pipes []U.BidirectionalPipe
TCPConnMap = F.Map[net.Conn, struct{}]
TCPRoute struct {
*StreamRoute
listener net.Listener
pipe Pipes
mu sync.Mutex
listener *net.TCPListener
}
)
func NewTCPRoute(base *StreamRoute) StreamImpl {
return &TCPRoute{
StreamRoute: base,
pipe: make(Pipes, 0),
}
return &TCPRoute{StreamRoute: base}
}
func (route *TCPRoute) Setup() error {
@@ -38,11 +32,12 @@ func (route *TCPRoute) Setup() error {
}
//! this read the allocated port from original ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in
route.listener = in.(*net.TCPListener)
return nil
}
func (route *TCPRoute) Accept() (any, error) {
route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept()
}
@@ -50,24 +45,23 @@ func (route *TCPRoute) Accept() (any, error) {
clientConn := c.(net.Conn)
defer clientConn.Close()
go func() {
<-route.task.Context().Done()
clientConn.Close()
}()
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
defer cancel()
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
cancel()
if err != nil {
return err
}
route.mu.Lock()
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
route.pipe = append(route.pipe, pipe)
route.mu.Unlock()
return pipe.Start()
}
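The one-second deadline set in Accept pairs with the accept loop in the stream route above: Accept returns a timeout error at least once per second, the loop discards pure timeouts via the net.OpError check, and task cancellation is therefore noticed promptly without closing the listener from another goroutine. A self-contained sketch of the idiom:

package main

import (
	"context"
	"errors"
	"log"
	"net"
	"time"
)

// acceptLoop sketches the interruptible-accept idiom used above.
func acceptLoop(ctx context.Context, ln *net.TCPListener, handle func(net.Conn)) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}
		ln.SetDeadline(time.Now().Add(time.Second)) // force periodic wakeups
		conn, err := ln.Accept()
		if err != nil {
			var ne *net.OpError
			if errors.As(err, &ne) && ne.Timeout() {
				continue // just the deadline firing; loop to re-check ctx
			}
			log.Println(err)
			continue
		}
		go handle(conn)
	}
}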

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
@@ -67,6 +68,7 @@ func (route *UDPRoute) Accept() (any, error) {
in := route.listeningConn
buffer := make([]byte, udpBufferSize)
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err

@@ -84,7 +84,7 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider,
http: httpSer,
https: httpsSer,
task: common.GlobalTask("Server " + opt.Name),
task: common.GlobalTask(opt.Name + " server"),
}
}
@@ -133,13 +133,11 @@ func (s *Server) stop() {
if s.http != nil && s.httpStarted {
s.handleErr("http", s.http.Shutdown(ctx))
s.httpStarted = false
logger.Debugf("HTTP server %q stopped", s.Name)
}
if s.https != nil && s.httpsStarted {
s.handleErr("https", s.https.Shutdown(ctx))
s.httpsStarted = false
logger.Debugf("HTTPS server %q stopped", s.Name)
}
}
@@ -152,7 +150,7 @@ func (s *Server) handleErr(scheme string, err error) {
case err == nil, errors.Is(err, http.ErrServerClosed):
return
default:
logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err)
logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err)
}
}

@@ -22,9 +22,9 @@ type (
// raw entry object before validation
// loaded from docker labels or yaml file
Alias string `json:"-" yaml:"-"`
Scheme string `json:"scheme" yaml:"scheme"`
Host string `json:"host" yaml:"host"`
Port string `json:"port" yaml:"port"`
Scheme string `json:"scheme,omitempty" yaml:"scheme"`
Host string `json:"host,omitempty" yaml:"host"`
Port string `json:"port,omitempty" yaml:"port"`
NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only
PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only
HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"`
@@ -33,7 +33,7 @@ type (
Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"`
/* Docker only */
*docker.Container `json:"container" yaml:"-"`
Container *docker.Container `json:"container,omitempty" yaml:"-"`
}
RawEntries = F.Map[string, *RawEntry]
@@ -43,16 +43,17 @@ var NewProxyEntries = F.NewMapOf[string, *RawEntry]
func (e *RawEntry) FillMissingFields() {
isDocker := e.Container != nil
cont := e.Container
if !isDocker {
e.Container = &docker.Container{}
cont = docker.DummyContainer
}
if e.Host == "" {
switch {
case e.PrivateIP != "":
e.Host = e.PrivateIP
case e.PublicIP != "":
e.Host = e.PublicIP
case cont.PrivateIP != "":
e.Host = cont.PrivateIP
case cont.PublicIP != "":
e.Host = cont.PublicIP
case !isDocker:
e.Host = "localhost"
}
@@ -60,14 +61,14 @@ func (e *RawEntry) FillMissingFields() {
lp, pp, extra := e.splitPorts()
if port, ok := common.ServiceNamePortMapTCP[e.ImageName]; ok {
if port, ok := common.ServiceNamePortMapTCP[cont.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
if e.Scheme == "" {
e.Scheme = "tcp"
}
} else if port, ok := common.ImageNamePortMap[e.ImageName]; ok {
} else if port, ok := common.ImageNamePortMap[cont.ImageName]; ok {
if pp == "" {
pp = strconv.Itoa(port)
}
@@ -77,9 +78,9 @@ func (e *RawEntry) FillMissingFields() {
} else if pp == "" && e.Scheme == "https" {
pp = "443"
} else if pp == "" {
if p := lowestPort(e.PrivatePortMapping); p != "" {
if p := lowestPort(cont.PrivatePortMapping); p != "" {
pp = p
} else if p := lowestPort(e.PublicPortMapping); p != "" {
} else if p := lowestPort(cont.PublicPortMapping); p != "" {
pp = p
} else if !isDocker {
pp = "80"
@@ -89,23 +90,23 @@ func (e *RawEntry) FillMissingFields() {
}
// replace private port with public port if using public IP.
if e.Host == e.PublicIP {
if p, ok := e.PrivatePortMapping[pp]; ok {
if e.Host == cont.PublicIP {
if p, ok := cont.PrivatePortMapping[pp]; ok {
pp = U.PortString(p.PublicPort)
}
}
// replace public port with private port if using private IP.
if e.Host == e.PrivateIP {
if p, ok := e.PublicPortMapping[pp]; ok {
if e.Host == cont.PrivateIP {
if p, ok := cont.PublicPortMapping[pp]; ok {
pp = U.PortString(p.PrivatePort)
}
}
if e.Scheme == "" && isDocker {
switch {
case e.Host == e.PublicIP && e.PublicPortMapping[pp].Type == "udp":
case e.Host == cont.PublicIP && cont.PublicPortMapping[pp].Type == "udp":
e.Scheme = "udp"
case e.Host == e.PrivateIP && e.PrivatePortMapping[pp].Type == "udp":
case e.Host == cont.PrivateIP && cont.PrivatePortMapping[pp].Type == "udp":
e.Scheme = "udp"
}
}
@@ -127,17 +128,17 @@ func (e *RawEntry) FillMissingFields() {
if e.HealthCheck.Timeout == 0 {
e.HealthCheck.Timeout = common.HealthCheckTimeoutDefault
}
if e.IdleTimeout == "" {
e.IdleTimeout = common.IdleTimeoutDefault
if cont.IdleTimeout == "" {
cont.IdleTimeout = common.IdleTimeoutDefault
}
if e.WakeTimeout == "" {
e.WakeTimeout = common.WakeTimeoutDefault
if cont.WakeTimeout == "" {
cont.WakeTimeout = common.WakeTimeoutDefault
}
if e.StopTimeout == "" {
e.StopTimeout = common.StopTimeoutDefault
if cont.StopTimeout == "" {
cont.StopTimeout = common.StopTimeoutDefault
}
if e.StopMethod == "" {
e.StopMethod = common.StopMethodDefault
if cont.StopMethod == "" {
cont.StopMethod = common.StopMethodDefault
}
e.Port = joinPorts(lp, pp, extra)
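A worked example of the port swap above (values illustrative): suppose a container exposes 8080/tcp, published on the host as 32768. When the resolved Host is the container's public IP, a configured port of 8080 is rewritten to 32768, since traffic must enter through the published port; when Host is the private IP, 32768 is rewritten back to 8080, since the proxy reaches the container directly.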

@@ -99,9 +99,53 @@ func (p BidirectionalPipe) Start() error {
return b.Build().Error()
}
func Copy(dst *ContextWriter, src *ContextReader) error {
_, err := io.Copy(dst, src)
return err
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This is a copy of io.Copy with context handling
// Author: yusing <yusing@6uo.me>
func Copy(dst *ContextWriter, src *ContextReader) (err error) {
size := 32 * 1024
if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf := make([]byte, size)
for {
select {
case <-src.ctx.Done():
return src.ctx.Err()
case <-dst.ctx.Done():
return dst.ctx.Err()
default:
nr, er := src.Reader.Read(buf)
if nr > 0 {
nw, ew := dst.Writer.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write result")
}
}
if ew != nil {
err = ew
return
}
if nr != nw {
err = io.ErrShortWrite
return
}
}
if er != nil {
if er != io.EOF {
err = er
}
return
}
}
}
}
func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {

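One property of the context-aware Copy above worth noting: cancellation is only checked between reads, so a Read blocked on a dead peer still blocks until it returns. The codebase compensates by closing the underlying connection when the context is done, as in the TCP Handle path earlier in this commit:

go func() {
	<-route.task.Context().Done()
	clientConn.Close() // unblocks any in-flight Read/Write so Copy returns
}()
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()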
@@ -99,12 +99,10 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
continue // Ignore this field if the tag is "-"
}
if strings.Contains(jsonTag, ",omitempty") {
if field.Type.Kind() == reflect.Ptr && value.Field(i).IsNil() {
continue
}
if value.Field(i).IsZero() {
continue
}
jsonTag = strings.Replace(jsonTag, ",omitempty", "", 1)
}
// If the json tag is not empty, use it as the key

@@ -7,7 +7,7 @@
)
type HealthCheckConfig struct {
Disabled bool `json:"disabled" yaml:"disabled"`
Disabled bool `json:"disabled,omitempty" yaml:"disabled"`
Path string `json:"path,omitempty" yaml:"path"`
UseGet bool `json:"use_get,omitempty" yaml:"use_get"`
Interval time.Duration `json:"interval" yaml:"interval"`

@@ -0,0 +1,30 @@
package health
import (
"encoding/json"
"time"
"github.com/yusing/go-proxy/internal/net/types"
)
type JSONRepresentation struct {
Name string
Config *HealthCheckConfig
Status Status
Started time.Time
Uptime time.Duration
URL types.URL
Extra map[string]any
}
func (jsonRepr *JSONRepresentation) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": jsonRepr.Name,
"config": jsonRepr.Config,
"started": jsonRepr.Started.Unix(),
"status": jsonRepr.Status.String(),
"uptime": jsonRepr.Uptime.Seconds(),
"url": jsonRepr.URL.String(),
"extra": jsonRepr.Extra,
})
}
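For reference, MarshalJSON flattens everything into a single object; a populated HTTP monitor would serialize roughly as follows (values illustrative, keys exactly as in the map above; time.Duration fields inside config marshal as nanoseconds):

{
  "name": "my-app",
  "config": {"interval": 5000000000, "timeout": 5000000000},
  "started": 1728974067,
  "status": "healthy",
  "uptime": 42.5,
  "url": "http://172.17.0.2:8080",
  "extra": null
}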

@@ -2,7 +2,6 @@ package health
import (
"context"
"encoding/json"
"errors"
"sync"
"time"
@@ -25,6 +24,7 @@ type (
}
HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct {
service string
config *HealthCheckConfig
url types.URL
@@ -43,8 +43,10 @@ type (
var monMap = F.NewMapOf[string, HealthMonitor]()
func newMonitor(task common.Task, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor {
task, cancel := task.SubtaskWithCancel("Health monitor for %s", task.Name())
service := task.Name()
task, cancel := task.SubtaskWithCancel("Health monitor for %s", service)
mon := &monitor{
service: service,
config: config,
url: url,
checkHealth: healthCheckFunc,
@@ -57,17 +59,12 @@ var monMap = F.NewMapOf[string, HealthMonitor]()
return mon
}
func Inspect(name string) (status Status, ok bool) {
mon, ok := monMap.Load(name)
if !ok {
return
}
return mon.Status(), true
func Inspect(name string) (HealthMonitor, bool) {
return monMap.Load(name)
}
func (mon *monitor) Start() {
defer monMap.Store(mon.task.Name(), mon)
defer logger.Debugf("%s health monitor started", mon.String())
go func() {
defer close(mon.done)
@@ -93,12 +90,9 @@ func (mon *monitor) Start() {
}
}
}()
logger.Debugf("health monitor %q started", mon.String())
}
func (mon *monitor) Stop() {
defer logger.Debugf("%s health monitor stopped", mon.String())
monMap.Delete(mon.task.Name())
mon.mu.Lock()
@@ -132,14 +126,14 @@ func (mon *monitor) String() string {
}
func (mon *monitor) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"name": mon.Name(),
"url": mon.url,
"status": mon.status.Load(),
"uptime": mon.Uptime().String(),
"started": mon.startTime.Unix(),
"config": mon.config,
})
return (&JSONRepresentation{
Name: mon.Name(),
Config: mon.config,
Status: mon.status.Load(),
Started: mon.startTime,
Uptime: mon.Uptime(),
URL: mon.url,
}).MarshalJSON()
}
func (mon *monitor) checkUpdateHealth() (hasError bool) {
@@ -147,7 +141,7 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
if err != nil {
mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) {
logger.Errorf("%s failed to check health: %s", mon.String(), err)
logger.Errorf("%s failed to check health: %s", mon.service, err)
}
mon.Stop()
return false
@@ -160,9 +154,9 @@ func (mon *monitor) checkUpdateHealth() (hasError bool) {
}
if healthy != (mon.status.Swap(status) == StatusHealthy) {
if healthy {
logger.Infof("%s is up", mon.String())
logger.Infof("%s is up", mon.service)
} else {
logger.Warnf("%s is down: %s", mon.String(), detail)
logger.Warnf("%s is down: %s", mon.service, detail)
}
}

@@ -1,7 +1,5 @@
package health
import "encoding/json"
type Status int
const (
@@ -36,7 +34,7 @@ func (s Status) String() string {
}
func (s Status) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String())
return []byte(`"` + s.String() + `"`), nil
}
func (s Status) Good() bool {