mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fixed loadbalancer with idlewatcher, fixed reload issue
This commit is contained in:
parent
01ffe0d97c
commit
a278711421
78 changed files with 906 additions and 609 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -24,3 +24,4 @@ todo.md
|
|||
.aider*
|
||||
mtrace.json
|
||||
.env
|
||||
test.Dockerfile
|
||||
|
|
4
Makefile
4
Makefile
|
@ -55,7 +55,7 @@ repush:
|
|||
git push gitlab dev --force
|
||||
|
||||
rapid-crash:
|
||||
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\
|
||||
sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\
|
||||
sleep 3 &&\
|
||||
sudo docker rm -f test_crash
|
||||
|
||||
|
@ -64,4 +64,4 @@ debug-list-containers:
|
|||
|
||||
ci-test:
|
||||
mkdir -p /tmp/artifacts
|
||||
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
|
||||
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
|
||||
|
|
|
@ -84,7 +84,7 @@ func main() {
|
|||
middleware.LoadComposeFiles()
|
||||
|
||||
var cfg *config.Config
|
||||
var err E.NestedError
|
||||
var err E.Error
|
||||
if cfg, err = config.Load(); err != nil {
|
||||
logrus.Warn(err)
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var validateErr E.NestedError
|
||||
var validateErr E.Error
|
||||
if filename == common.ConfigFileName {
|
||||
validateErr = config.Validate(content)
|
||||
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
)
|
||||
|
||||
func ReloadServer() E.NestedError {
|
||||
func ReloadServer() E.Error {
|
||||
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
|
@ -34,7 +34,7 @@ func ReloadServer() E.NestedError {
|
|||
return nil
|
||||
}
|
||||
|
||||
func List[T any](what string) (_ T, outErr E.NestedError) {
|
||||
func List[T any](what string) (_ T, outErr E.Error) {
|
||||
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
|
||||
if err != nil {
|
||||
outErr = E.From(err)
|
||||
|
@ -54,14 +54,14 @@ func List[T any](what string) (_ T, outErr E.NestedError) {
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func ListRoutes() (map[string]map[string]any, E.NestedError) {
|
||||
func ListRoutes() (map[string]map[string]any, E.Error) {
|
||||
return List[map[string]map[string]any](v1.ListRoutes)
|
||||
}
|
||||
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
|
||||
func ListMiddlewareTraces() (middleware.Traces, E.Error) {
|
||||
return List[middleware.Traces](v1.ListMiddlewareTraces)
|
||||
}
|
||||
|
||||
func DebugListTasks() (map[string]any, E.NestedError) {
|
||||
func DebugListTasks() (map[string]any, E.Error) {
|
||||
return List[map[string]any](v1.ListTasks)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
|
|||
return (*Config)(cfg)
|
||||
}
|
||||
|
||||
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
|
||||
func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
|
||||
b := E.NewBuilder("unable to initialize autocert")
|
||||
defer b.To(&res)
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ type (
|
|||
tlsCert *tls.Certificate
|
||||
certExpiries CertExpiries
|
||||
}
|
||||
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
|
||||
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
|
||||
|
||||
CertExpiries map[string]time.Time
|
||||
)
|
||||
|
@ -57,7 +57,7 @@ func (p *Provider) GetExpiries() CertExpiries {
|
|||
return p.certExpiries
|
||||
}
|
||||
|
||||
func (p *Provider) ObtainCert() (res E.NestedError) {
|
||||
func (p *Provider) ObtainCert() (res E.Error) {
|
||||
b := E.NewBuilder("failed to obtain certificate")
|
||||
defer b.To(&res)
|
||||
|
||||
|
@ -112,7 +112,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) LoadCert() E.NestedError {
|
||||
func (p *Provider) LoadCert() E.Error {
|
||||
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
|
||||
if err.HasError() {
|
||||
return err
|
||||
|
@ -158,7 +158,7 @@ func (p *Provider) ScheduleRenewal() {
|
|||
}()
|
||||
}
|
||||
|
||||
func (p *Provider) initClient() E.NestedError {
|
||||
func (p *Provider) initClient() E.Error {
|
||||
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
|
||||
if err.HasError() {
|
||||
return E.FailWith("create lego client", err)
|
||||
|
@ -178,7 +178,7 @@ func (p *Provider) initClient() E.NestedError {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) registerACME() E.NestedError {
|
||||
func (p *Provider) registerACME() E.Error {
|
||||
if p.user.Registration != nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ func (p *Provider) registerACME() E.NestedError {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) E.Error {
|
||||
/* This should have been done in setup
|
||||
but double check is always a good choice.*/
|
||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||
|
@ -239,7 +239,7 @@ func (p *Provider) certState() CertState {
|
|||
return CertStateValid
|
||||
}
|
||||
|
||||
func (p *Provider) renewIfNeeded() E.NestedError {
|
||||
func (p *Provider) renewIfNeeded() E.Error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
@ -259,7 +259,7 @@ func (p *Provider) renewIfNeeded() E.NestedError {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) {
|
||||
r := make(CertExpiries, len(cert.Certificate))
|
||||
for _, cert := range cert.Certificate {
|
||||
x509Cert, err := E.Check(x509.ParseCertificate(cert))
|
||||
|
@ -281,7 +281,7 @@ func providerGenerator[CT any, PT challenge.Provider](
|
|||
defaultCfg func() *CT,
|
||||
newProvider func(*CT) (PT, error),
|
||||
) ProviderGenerator {
|
||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
|
||||
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := U.Deserialize(opt, cfg)
|
||||
if err.HasError() {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func (p *Provider) Setup() (err E.NestedError) {
|
||||
func (p *Provider) Setup() (err E.Error) {
|
||||
if err = p.LoadCert(); err != nil {
|
||||
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
|
||||
return err
|
||||
|
|
|
@ -47,3 +47,5 @@ const (
|
|||
StopTimeoutDefault = "10s"
|
||||
StopMethodDefault = "stop"
|
||||
)
|
||||
|
||||
const HeaderCheckRedirect = "X-Goproxy-Check-Redirect"
|
||||
|
|
|
@ -55,7 +55,7 @@ func newConfig() *Config {
|
|||
}
|
||||
}
|
||||
|
||||
func Load() (*Config, E.NestedError) {
|
||||
func Load() (*Config, E.Error) {
|
||||
if instance != nil {
|
||||
return instance, nil
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func Load() (*Config, E.NestedError) {
|
|||
return instance, instance.load()
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
func Validate(data []byte) E.Error {
|
||||
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,7 @@ func WatchChanges() {
|
|||
task,
|
||||
configEventFlushInterval,
|
||||
OnConfigChange,
|
||||
func(err E.NestedError) {
|
||||
func(err E.Error) {
|
||||
logger.Error(err)
|
||||
},
|
||||
)
|
||||
|
@ -104,7 +104,7 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
|
|||
}
|
||||
}
|
||||
|
||||
func Reload() E.NestedError {
|
||||
func Reload() E.Error {
|
||||
// avoid race between config change and API reload request
|
||||
reloadMu.Lock()
|
||||
defer reloadMu.Unlock()
|
||||
|
@ -139,7 +139,7 @@ func (cfg *Config) Task() task.Task {
|
|||
func (cfg *Config) StartProxyProviders() {
|
||||
b := E.NewBuilder("errors starting providers")
|
||||
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
|
||||
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName())))
|
||||
b.Add(p.Start(cfg.task.Subtask(p.String())))
|
||||
})
|
||||
|
||||
if b.HasError() {
|
||||
|
@ -147,7 +147,7 @@ func (cfg *Config) StartProxyProviders() {
|
|||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) load() (res E.NestedError) {
|
||||
func (cfg *Config) load() (res E.Error) {
|
||||
b := E.NewBuilder("errors loading config")
|
||||
defer b.To(&res)
|
||||
|
||||
|
@ -182,7 +182,7 @@ func (cfg *Config) load() (res E.NestedError) {
|
|||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) {
|
||||
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
|
||||
if cfg.autocertProvider != nil {
|
||||
return
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
|
|||
return
|
||||
}
|
||||
|
||||
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) {
|
||||
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.Error) {
|
||||
subtask := cfg.task.Subtask("load providers")
|
||||
defer subtask.Finish("done")
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
task.GlobalTask("close docker clients").OnComplete("", func() {
|
||||
task.GlobalTask("close docker clients").OnFinished("", func() {
|
||||
clientMap.RangeAllParallel(func(_ string, c Client) {
|
||||
if c.Connected() {
|
||||
c.Client.Close()
|
||||
|
@ -70,7 +70,7 @@ func (c *SharedClient) Close() error {
|
|||
// Returns:
|
||||
// - Client: the Docker client connection.
|
||||
// - error: an error if the connection failed.
|
||||
func ConnectClient(host string) (Client, E.NestedError) {
|
||||
func ConnectClient(host string) (Client, E.Error) {
|
||||
clientMapMu.Lock()
|
||||
defer clientMapMu.Unlock()
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
type templateData struct {
|
||||
|
@ -18,17 +20,15 @@ type templateData struct {
|
|||
var loadingPage []byte
|
||||
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
|
||||
|
||||
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
|
||||
|
||||
func (w *Watcher) makeLoadingPageBody() []byte {
|
||||
msg := fmt.Sprintf("%s is starting...", w.ContainerName)
|
||||
|
||||
data := new(templateData)
|
||||
data.CheckRedirectHeader = headerCheckRedirect
|
||||
data.CheckRedirectHeader = common.HeaderCheckRedirect
|
||||
data.Title = w.ContainerName
|
||||
data.Message = strings.ReplaceAll(msg, " ", " ")
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect)))
|
||||
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect)))
|
||||
err := loadingPageTmpl.Execute(buf, data)
|
||||
if err != nil { // should never happen in production
|
||||
panic(err)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package idlewatcher
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
@ -30,7 +30,7 @@ const (
|
|||
StopMethodKill StopMethod = "kill"
|
||||
)
|
||||
|
||||
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
|
||||
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
|
||||
if cont == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
|
||||
func validateDurationPostitive(value string) (time.Duration, E.Error) {
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return 0, E.Invalid("duration", value).With(err)
|
||||
|
@ -91,7 +91,7 @@ func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
|
|||
return d, nil
|
||||
}
|
||||
|
||||
func validateSignal(s string) (Signal, E.NestedError) {
|
||||
func validateSignal(s string) (Signal, E.Error) {
|
||||
switch s {
|
||||
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
|
||||
"INT", "TERM", "HUP", "QUIT":
|
||||
|
@ -101,7 +101,7 @@ func validateSignal(s string) (Signal, E.NestedError) {
|
|||
return "", E.Invalid("signal", s)
|
||||
}
|
||||
|
||||
func validateStopMethod(s string) (StopMethod, E.NestedError) {
|
||||
func validateStopMethod(s string) (StopMethod, E.Error) {
|
||||
sm := StopMethod(s)
|
||||
switch sm {
|
||||
case StopMethodPause, StopMethodStop, StopMethodKill:
|
14
internal/docker/idlewatcher/types/waker.go
Normal file
14
internal/docker/idlewatcher/types/waker.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Waker interface {
|
||||
health.HealthMonitor
|
||||
http.Handler
|
||||
net.Stream
|
||||
}
|
|
@ -1,10 +1,10 @@
|
|||
package idlewatcher
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
|
@ -14,12 +14,6 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
type Waker interface {
|
||||
health.HealthMonitor
|
||||
http.Handler
|
||||
net.Stream
|
||||
}
|
||||
|
||||
type waker struct {
|
||||
_ U.NoCopy
|
||||
|
||||
|
@ -37,7 +31,7 @@ const (
|
|||
|
||||
// TODO: support stream
|
||||
|
||||
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) {
|
||||
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
|
||||
hcCfg := entry.HealthCheckConfig()
|
||||
hcCfg.Timeout = idleWakerCheckTimeout
|
||||
|
||||
|
@ -62,24 +56,26 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
|
|||
}
|
||||
|
||||
// lifetime should follow route provider
|
||||
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) {
|
||||
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
|
||||
return newWaker(providerSubTask, entry, rp, nil)
|
||||
}
|
||||
|
||||
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) {
|
||||
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) {
|
||||
return newWaker(providerSubTask, entry, nil, stream)
|
||||
}
|
||||
|
||||
// Start implements health.HealthMonitor.
|
||||
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError {
|
||||
w.task.OnComplete("stop route", func() {
|
||||
routeSubTask.Parent().Finish("watcher stopped")
|
||||
func (w *Watcher) Start(routeSubTask task.Task) E.Error {
|
||||
routeSubTask.Finish("ignored")
|
||||
w.task.OnCancel("stop route", func() {
|
||||
routeSubTask.Parent().Finish(w.task.FinishCause())
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finish implements health.HealthMonitor.
|
||||
func (w *Watcher) Finish(reason string) {}
|
||||
func (w *Watcher) Finish(reason any) {
|
||||
}
|
||||
|
||||
// Name implements health.HealthMonitor.
|
||||
func (w *Watcher) Name() string {
|
||||
|
@ -109,6 +105,7 @@ func (w *Watcher) Status() health.Status {
|
|||
healthy, _, err := w.hc.CheckHealth()
|
||||
switch {
|
||||
case err != nil:
|
||||
w.ready.Store(false)
|
||||
return health.StatusError
|
||||
case healthy:
|
||||
w.ready.Store(true)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
|
@ -37,7 +38,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
accept := gphttp.GetAccept(r.Header)
|
||||
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
|
||||
|
||||
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
|
||||
isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != ""
|
||||
if !isCheckRedirect && acceptHTML {
|
||||
// Send a loading response to the client
|
||||
body := w.makeLoadingPageBody()
|
||||
|
@ -56,14 +57,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
|
||||
defer cancel()
|
||||
|
||||
checkCancelled := func() bool {
|
||||
checkCanceled := func() bool {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context()))
|
||||
w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
w.l.Debugf("wake cancelled: %s", context.Cause(ctx))
|
||||
w.l.Debugf("wake canceled: %s", context.Cause(ctx))
|
||||
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
|
||||
return true
|
||||
default:
|
||||
|
@ -71,7 +72,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
}
|
||||
}
|
||||
|
||||
if checkCancelled() {
|
||||
if checkCanceled() {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -84,14 +85,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
}
|
||||
|
||||
for {
|
||||
if checkCancelled() {
|
||||
if checkCanceled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if w.Status() == health.StatusHealthy {
|
||||
w.resetIdleTimer()
|
||||
if isCheckRedirect {
|
||||
logrus.Debugf("container %s is ready, redirecting...", w.String())
|
||||
logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -8,44 +8,47 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health"
|
||||
)
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Addr() net.Addr {
|
||||
return w.stream.Addr()
|
||||
}
|
||||
|
||||
func (w *Watcher) Setup() error {
|
||||
return w.stream.Setup()
|
||||
}
|
||||
|
||||
// Accept implements types.Stream.
|
||||
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
|
||||
func (w *Watcher) Accept() (conn net.Conn, err error) {
|
||||
conn, err = w.stream.Accept()
|
||||
// timeout means no connection is accepted
|
||||
var nErr *net.OpError
|
||||
ok := errors.As(err, &nErr)
|
||||
if ok && nErr.Timeout() {
|
||||
if err != nil {
|
||||
logrus.Errorf("accept failed with error: %s", err)
|
||||
return
|
||||
}
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return nil, err
|
||||
w.l.Error(err)
|
||||
}
|
||||
return w.stream.Accept()
|
||||
}
|
||||
|
||||
// CloseListeners implements types.Stream.
|
||||
func (w *Watcher) CloseListeners() {
|
||||
w.stream.CloseListeners()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle implements types.Stream.
|
||||
func (w *Watcher) Handle(conn types.StreamConn) error {
|
||||
func (w *Watcher) Handle(conn net.Conn) error {
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.stream.Handle(conn)
|
||||
}
|
||||
|
||||
// Close implements types.Stream.
|
||||
func (w *Watcher) Close() error {
|
||||
return w.stream.Close()
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromStream() error {
|
||||
w.resetIdleTimer()
|
||||
|
||||
// pass through if container is already ready
|
||||
if w.ready.Load() {
|
||||
return nil
|
||||
|
@ -66,11 +69,11 @@ func (w *Watcher) wakeFromStream() error {
|
|||
select {
|
||||
case <-w.task.Context().Done():
|
||||
cause := w.task.FinishCause()
|
||||
w.l.Debugf("wake cancelled: %s", cause)
|
||||
w.l.Debugf("wake canceled: %s", cause)
|
||||
return cause
|
||||
case <-ctx.Done():
|
||||
cause := context.Cause(ctx)
|
||||
w.l.Debugf("wake cancelled: %s", cause)
|
||||
w.l.Debugf("wake canceled: %s", cause)
|
||||
return cause
|
||||
default:
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -49,7 +49,7 @@ var (
|
|||
|
||||
const dockerReqTimeout = 3 * time.Second
|
||||
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) {
|
||||
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) {
|
||||
failure := E.Failure("idle_watcher register")
|
||||
cfg := entry.IdlewatcherConfig()
|
||||
|
||||
|
@ -66,6 +66,7 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
|
|||
w.Config = cfg
|
||||
w.waker = waker
|
||||
w.resetIdleTimer()
|
||||
providerSubtask.Finish("used existing watcher")
|
||||
return w, nil
|
||||
}
|
||||
|
||||
|
@ -88,13 +89,11 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
|
|||
go func() {
|
||||
cause := w.watchUntilDestroy()
|
||||
|
||||
watcherMapMu.Lock()
|
||||
watcherMap.Delete(w.ContainerID)
|
||||
watcherMapMu.Unlock()
|
||||
|
||||
w.ticker.Stop()
|
||||
w.client.Close()
|
||||
w.task.Finish(cause.Error())
|
||||
w.task.Finish(cause)
|
||||
}()
|
||||
|
||||
return w, nil
|
||||
|
@ -146,7 +145,7 @@ func (w *Watcher) wakeIfStopped() error {
|
|||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout)
|
||||
defer cancel()
|
||||
|
||||
// !Hard coded here since theres no constants from Docker API
|
||||
|
@ -175,7 +174,7 @@ func (w *Watcher) getStopCallback() StopCallback {
|
|||
panic("should not reach here")
|
||||
}
|
||||
return func() error {
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
|
||||
ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second)
|
||||
defer cancel()
|
||||
return cb(ctx)
|
||||
}
|
||||
|
@ -186,8 +185,8 @@ func (w *Watcher) resetIdleTimer() {
|
|||
w.ticker.Reset(w.IdleTimeout)
|
||||
}
|
||||
|
||||
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) {
|
||||
eventTask = w.task.Subtask("watcher for %s", w.ContainerID)
|
||||
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
|
||||
eventTask = w.task.Subtask("docker event watcher")
|
||||
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
|
||||
Filters: W.NewDockerFilter(
|
||||
W.DockerFilterContainer,
|
||||
|
@ -218,13 +217,12 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
|
|||
func (w *Watcher) watchUntilDestroy() error {
|
||||
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
|
||||
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
|
||||
defer eventTask.Finish("stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.task.Context().Done():
|
||||
cause := context.Cause(w.task.Context())
|
||||
w.l.Debugf("watcher stopped by context done: %s", cause)
|
||||
return cause
|
||||
return w.task.FinishCause()
|
||||
case err := <-dockerEventErrCh:
|
||||
if err != nil && err.IsNot(context.Canceled) {
|
||||
w.l.Error(E.FailWith("docker watcher", err))
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) {
|
||||
func Inspect(dockerHost string, containerID string) (*Container, E.Error) {
|
||||
client, err := ConnectClient(dockerHost)
|
||||
defer client.Close()
|
||||
|
||||
|
@ -19,7 +19,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
|
|||
return client.Inspect(containerID)
|
||||
}
|
||||
|
||||
func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
|
||||
func (c Client) Inspect(containerID string) (*Container, E.Error) {
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
|
||||
defer cancel()
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ func (l *Label) String() string {
|
|||
//
|
||||
// Returns:
|
||||
// - error: an error if the field does not exist.
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
||||
func ApplyLabel[T any](obj *T, l *Label) E.Error {
|
||||
if obj == nil {
|
||||
return E.Invalid("nil object", l)
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
|
|||
}
|
||||
}
|
||||
|
||||
func ParseLabel(label string, value string) (*Label, E.NestedError) {
|
||||
func ParseLabel(label string, value string) (*Label, E.Error) {
|
||||
parts := strings.Split(label, ".")
|
||||
|
||||
if len(parts) < 2 {
|
||||
|
|
|
@ -11,11 +11,6 @@ import (
|
|||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type ClientInfo struct {
|
||||
Client Client
|
||||
Containers []types.Container
|
||||
}
|
||||
|
||||
var listOptions = container.ListOptions{
|
||||
// created|restarting|running|removing|paused|exited|dead
|
||||
// Filters: filters.NewArgs(
|
||||
|
@ -28,28 +23,21 @@ var listOptions = container.ListOptions{
|
|||
All: true,
|
||||
}
|
||||
|
||||
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
|
||||
func ListContainers(clientHost string) ([]types.Container, E.Error) {
|
||||
dockerClient, err := ConnectClient(clientHost)
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("connect to docker", err)
|
||||
}
|
||||
defer dockerClient.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout"))
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
|
||||
defer cancel()
|
||||
|
||||
var containers []types.Container
|
||||
if getContainer {
|
||||
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("list containers", err)
|
||||
}
|
||||
containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions))
|
||||
if err.HasError() {
|
||||
return nil, E.FailWith("list containers", err)
|
||||
}
|
||||
|
||||
return &ClientInfo{
|
||||
Client: dockerClient,
|
||||
Containers: containers,
|
||||
}, nil
|
||||
return containers, nil
|
||||
}
|
||||
|
||||
func IsErrConnectionFailed(err error) bool {
|
|
@ -12,7 +12,7 @@ type Builder struct {
|
|||
|
||||
type builder struct {
|
||||
message string
|
||||
errors []NestedError
|
||||
errors []Error
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ func NewBuilder(format string, args ...any) Builder {
|
|||
// adding nil is no-op,
|
||||
//
|
||||
// flatten is a boolean flag to flatten the NestedError.
|
||||
func (b Builder) Add(err NestedError, flatten ...bool) {
|
||||
func (b Builder) Add(err Error, flatten ...bool) {
|
||||
if err != nil {
|
||||
b.Lock()
|
||||
if len(flatten) > 0 && flatten[0] {
|
||||
|
@ -54,7 +54,7 @@ func (b Builder) Addf(format string, args ...any) {
|
|||
}
|
||||
}
|
||||
|
||||
func (b Builder) AddRange(errs ...NestedError) {
|
||||
func (b Builder) AddRange(errs ...Error) {
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
for _, err := range errs {
|
||||
|
@ -77,14 +77,14 @@ func (b Builder) AddRangeE(errs ...error) {
|
|||
//
|
||||
// Returns:
|
||||
// - NestedError: the built NestedError.
|
||||
func (b Builder) Build() NestedError {
|
||||
func (b Builder) Build() Error {
|
||||
if len(b.errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return Join(b.message, b.errors...)
|
||||
}
|
||||
|
||||
func (b Builder) To(ptr *NestedError) {
|
||||
func (b Builder) To(ptr *Error) {
|
||||
switch {
|
||||
case ptr == nil:
|
||||
return
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestBuilderEmpty(t *testing.T) {
|
|||
|
||||
func TestBuilderAddNil(t *testing.T) {
|
||||
eb := NewBuilder("asdf")
|
||||
var err NestedError
|
||||
var err Error
|
||||
for range 3 {
|
||||
eb.Add(nil)
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ func TestBuilderTo(t *testing.T) {
|
|||
eb := NewBuilder("error occurred")
|
||||
eb.Addf("abcd")
|
||||
|
||||
var err NestedError
|
||||
var err Error
|
||||
eb.To(&err)
|
||||
got := err.String()
|
||||
expected := (`error occurred:
|
||||
|
|
|
@ -8,35 +8,35 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
NestedError = *NestedErrorImpl
|
||||
NestedErrorImpl struct {
|
||||
Error = *ErrorImpl
|
||||
ErrorImpl struct {
|
||||
subject string
|
||||
err error
|
||||
extras []NestedErrorImpl
|
||||
extras []ErrorImpl
|
||||
}
|
||||
JSONNestedError struct {
|
||||
Subject string `json:"subject"`
|
||||
Err string `json:"error"`
|
||||
Extras []JSONNestedError `json:"extras,omitempty"`
|
||||
ErrorJSONMarshaller struct {
|
||||
Subject string `json:"subject"`
|
||||
Err string `json:"error"`
|
||||
Extras []ErrorJSONMarshaller `json:"extras,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func From(err error) NestedError {
|
||||
func From(err error) Error {
|
||||
if IsNil(err) {
|
||||
return nil
|
||||
}
|
||||
return &NestedErrorImpl{err: err}
|
||||
return &ErrorImpl{err: err}
|
||||
}
|
||||
|
||||
func FromJSON(data []byte) (NestedError, bool) {
|
||||
var j JSONNestedError
|
||||
func FromJSON(data []byte) (Error, bool) {
|
||||
var j ErrorJSONMarshaller
|
||||
if err := json.Unmarshal(data, &j); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if j.Err == "" {
|
||||
return nil, false
|
||||
}
|
||||
extras := make([]NestedErrorImpl, len(j.Extras))
|
||||
extras := make([]ErrorImpl, len(j.Extras))
|
||||
for i, e := range j.Extras {
|
||||
extra, ok := fromJSONObject(e)
|
||||
if !ok {
|
||||
|
@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) {
|
|||
}
|
||||
extras[i] = *extra
|
||||
}
|
||||
return &NestedErrorImpl{
|
||||
return &ErrorImpl{
|
||||
subject: j.Subject,
|
||||
err: errors.New(j.Err),
|
||||
extras: extras,
|
||||
|
@ -53,12 +53,12 @@ func FromJSON(data []byte) (NestedError, bool) {
|
|||
|
||||
// Check is a helper function that
|
||||
// convert (T, error) to (T, NestedError).
|
||||
func Check[T any](obj T, err error) (T, NestedError) {
|
||||
func Check[T any](obj T, err error) (T, Error) {
|
||||
return obj, From(err)
|
||||
}
|
||||
|
||||
func Join(message string, err ...NestedError) NestedError {
|
||||
extras := make([]NestedErrorImpl, len(err))
|
||||
func Join(message string, err ...Error) Error {
|
||||
extras := make([]ErrorImpl, len(err))
|
||||
nErr := 0
|
||||
for i, e := range err {
|
||||
if e == nil {
|
||||
|
@ -70,13 +70,13 @@ func Join(message string, err ...NestedError) NestedError {
|
|||
if nErr == 0 {
|
||||
return nil
|
||||
}
|
||||
return &NestedErrorImpl{
|
||||
return &ErrorImpl{
|
||||
err: errors.New(message),
|
||||
extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func JoinE(message string, err ...error) NestedError {
|
||||
func JoinE(message string, err ...error) Error {
|
||||
b := NewBuilder("%s", message)
|
||||
for _, e := range err {
|
||||
b.AddE(e)
|
||||
|
@ -92,13 +92,13 @@ func IsNotNil(err error) bool {
|
|||
return err != nil
|
||||
}
|
||||
|
||||
func (ne NestedError) String() string {
|
||||
func (ne Error) String() string {
|
||||
var buf strings.Builder
|
||||
ne.writeToSB(&buf, 0, "")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (ne NestedError) Is(err error) bool {
|
||||
func (ne Error) Is(err error) bool {
|
||||
if ne == nil {
|
||||
return err == nil
|
||||
}
|
||||
|
@ -114,18 +114,18 @@ func (ne NestedError) Is(err error) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (ne NestedError) IsNot(err error) bool {
|
||||
func (ne Error) IsNot(err error) bool {
|
||||
return !ne.Is(err)
|
||||
}
|
||||
|
||||
func (ne NestedError) Error() error {
|
||||
func (ne Error) Error() error {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
return ne.buildError(0, "")
|
||||
}
|
||||
|
||||
func (ne NestedError) With(s any) NestedError {
|
||||
func (ne Error) With(s any) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ func (ne NestedError) With(s any) NestedError {
|
|||
switch ss := s.(type) {
|
||||
case nil:
|
||||
return ne
|
||||
case *NestedErrorImpl:
|
||||
case *ErrorImpl:
|
||||
if len(ss.extras) == 1 {
|
||||
ne.extras = append(ne.extras, ss.extras[0])
|
||||
return ne
|
||||
|
@ -151,11 +151,11 @@ func (ne NestedError) With(s any) NestedError {
|
|||
return ne.withError(From(errors.New(msg)))
|
||||
}
|
||||
|
||||
func (ne NestedError) Extraf(format string, args ...any) NestedError {
|
||||
func (ne Error) Extraf(format string, args ...any) Error {
|
||||
return ne.With(errorf(format, args...))
|
||||
}
|
||||
|
||||
func (ne NestedError) Subject(s any, sep ...string) NestedError {
|
||||
func (ne Error) Subject(s any, sep ...string) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
|
@ -179,26 +179,26 @@ func (ne NestedError) Subject(s any, sep ...string) NestedError {
|
|||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
|
||||
func (ne Error) Subjectf(format string, args ...any) Error {
|
||||
if ne == nil {
|
||||
return ne
|
||||
}
|
||||
return ne.Subject(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (ne NestedError) JSONObject() JSONNestedError {
|
||||
extras := make([]JSONNestedError, len(ne.extras))
|
||||
func (ne Error) JSONObject() ErrorJSONMarshaller {
|
||||
extras := make([]ErrorJSONMarshaller, len(ne.extras))
|
||||
for i, e := range ne.extras {
|
||||
extras[i] = e.JSONObject()
|
||||
}
|
||||
return JSONNestedError{
|
||||
return ErrorJSONMarshaller{
|
||||
Subject: ne.subject,
|
||||
Err: ne.err.Error(),
|
||||
Extras: extras,
|
||||
}
|
||||
}
|
||||
|
||||
func (ne NestedError) JSON() []byte {
|
||||
func (ne Error) JSON() []byte {
|
||||
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -206,19 +206,19 @@ func (ne NestedError) JSON() []byte {
|
|||
return b
|
||||
}
|
||||
|
||||
func (ne NestedError) NoError() bool {
|
||||
func (ne Error) NoError() bool {
|
||||
return ne == nil
|
||||
}
|
||||
|
||||
func (ne NestedError) HasError() bool {
|
||||
func (ne Error) HasError() bool {
|
||||
return ne != nil
|
||||
}
|
||||
|
||||
func errorf(format string, args ...any) NestedError {
|
||||
func errorf(format string, args ...any) Error {
|
||||
return From(fmt.Errorf(format, args...))
|
||||
}
|
||||
|
||||
func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
|
||||
func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
|
@ -226,14 +226,14 @@ func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
|
|||
return FromJSON(data)
|
||||
}
|
||||
|
||||
func (ne NestedError) withError(err NestedError) NestedError {
|
||||
func (ne Error) withError(err Error) Error {
|
||||
if ne != nil && err != nil {
|
||||
ne.extras = append(ne.extras, *err)
|
||||
}
|
||||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) appendMsg(msg string) NestedError {
|
||||
func (ne Error) appendMsg(msg string) Error {
|
||||
if ne == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -241,7 +241,7 @@ func (ne NestedError) appendMsg(msg string) NestedError {
|
|||
return ne
|
||||
}
|
||||
|
||||
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
|
||||
func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
|
||||
for range level {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
|
@ -266,7 +266,7 @@ func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ne NestedError) buildError(level int, prefix string) error {
|
||||
func (ne Error) buildError(level int, prefix string) error {
|
||||
var res error
|
||||
var sb strings.Builder
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ func TestErrorIs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestErrorNestedIs(t *testing.T) {
|
||||
var err NestedError
|
||||
var err Error
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
|
||||
err = Failure("some reason")
|
||||
|
@ -40,7 +40,7 @@ func TestErrorNestedIs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIsNil(t *testing.T) {
|
||||
var err NestedError
|
||||
var err Error
|
||||
ExpectTrue(t, err.Is(nil))
|
||||
ExpectTrue(t, err == nil)
|
||||
ExpectTrue(t, err.NoError())
|
||||
|
|
|
@ -22,62 +22,62 @@ var (
|
|||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
||||
func Failure(what string) NestedError {
|
||||
func Failure(what string) Error {
|
||||
return errorf("%s %w", what, ErrFailure)
|
||||
}
|
||||
|
||||
func FailedWhy(what string, why string) NestedError {
|
||||
func FailedWhy(what string, why string) Error {
|
||||
return Failure(what).With(why)
|
||||
}
|
||||
|
||||
func FailWith(what string, err any) NestedError {
|
||||
func FailWith(what string, err any) Error {
|
||||
return Failure(what).With(err)
|
||||
}
|
||||
|
||||
func Invalid(subject, what any) NestedError {
|
||||
func Invalid(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
|
||||
}
|
||||
|
||||
func Unsupported(subject, what any) NestedError {
|
||||
func Unsupported(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
|
||||
}
|
||||
|
||||
func Unexpected(subject, what any) NestedError {
|
||||
func Unexpected(subject, what any) Error {
|
||||
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
|
||||
}
|
||||
|
||||
func UnexpectedError(err error) NestedError {
|
||||
func UnexpectedError(err error) Error {
|
||||
return errorf("%w error: %w", ErrUnexpected, err)
|
||||
}
|
||||
|
||||
func NotExist(subject, what any) NestedError {
|
||||
func NotExist(subject, what any) Error {
|
||||
return errorf("%v %w: %v", subject, ErrNotExists, what)
|
||||
}
|
||||
|
||||
func Missing(subject any) NestedError {
|
||||
func Missing(subject any) Error {
|
||||
return errorf("%w %v", ErrMissing, subject)
|
||||
}
|
||||
|
||||
func Duplicated(subject, what any) NestedError {
|
||||
func Duplicated(subject, what any) Error {
|
||||
return errorf("%w %v: %v", ErrDuplicated, subject, what)
|
||||
}
|
||||
|
||||
func OutOfRange(subject any, value any) NestedError {
|
||||
func OutOfRange(subject any, value any) Error {
|
||||
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
|
||||
}
|
||||
|
||||
func TypeError(subject any, from, to reflect.Type) NestedError {
|
||||
func TypeError(subject any, from, to reflect.Type) Error {
|
||||
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
|
||||
}
|
||||
|
||||
func TypeError2(subject any, from, to reflect.Value) NestedError {
|
||||
func TypeError2(subject any, from, to reflect.Value) Error {
|
||||
return TypeError(subject, from.Type(), to.Type())
|
||||
}
|
||||
|
||||
func TypeMismatch[Expect any](value any) NestedError {
|
||||
func TypeMismatch[Expect any](value any) Error {
|
||||
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
|
||||
}
|
||||
|
||||
func PanicRecv(format string, args ...any) NestedError {
|
||||
func PanicRecv(format string, args ...any) Error {
|
||||
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
|
15
internal/net/http/loadbalancer/dummy_response_writer.go
Normal file
15
internal/net/http/loadbalancer/dummy_response_writer.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package loadbalancer
|
||||
|
||||
import "net/http"
|
||||
|
||||
type DummyResponseWriter struct{}
|
||||
|
||||
func (w *DummyResponseWriter) Header() (_ http.Header) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (w *DummyResponseWriter) WriteHeader(int) {}
|
|
@ -21,7 +21,7 @@ func (lb *LoadBalancer) newIPHash() impl {
|
|||
if len(lb.Options) == 0 {
|
||||
return impl
|
||||
}
|
||||
var err E.NestedError
|
||||
var err E.Error
|
||||
impl.realIP, err = middleware.NewRealIP(lb.Options)
|
||||
if err != nil {
|
||||
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/middleware"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -54,10 +57,10 @@ func New(cfg *Config) *LoadBalancer {
|
|||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
|
||||
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
|
||||
lb.startTime = time.Now()
|
||||
lb.task = routeSubtask
|
||||
lb.task.OnComplete("loadbalancer cleanup", func() {
|
||||
lb.task.OnFinished("loadbalancer cleanup", func() {
|
||||
if lb.impl != nil {
|
||||
lb.pool.RangeAll(func(k string, v *Server) {
|
||||
lb.impl.OnRemoveServer(v)
|
||||
|
@ -69,7 +72,7 @@ func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
|
|||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (lb *LoadBalancer) Finish(reason string) {
|
||||
func (lb *LoadBalancer) Finish(reason any) {
|
||||
lb.task.Finish(reason)
|
||||
}
|
||||
|
||||
|
@ -128,7 +131,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
|
|||
|
||||
lb.rebalance()
|
||||
lb.impl.OnAddServer(srv)
|
||||
logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
|
||||
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
|
@ -147,11 +150,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
|||
|
||||
if lb.pool.Size() == 0 {
|
||||
lb.task.Finish("no server left")
|
||||
logger.Infof("[remove] loadbalancer %s stopped", lb.Link)
|
||||
logger.Infof("loadbalancer %s stopped", lb.Link)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
|
||||
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) rebalance() {
|
||||
|
@ -211,6 +214,21 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
if r.Header.Get(common.HeaderCheckRedirect) != "" {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
|
||||
defer cancel()
|
||||
// send dummy request to wake all servers
|
||||
var dummyRW *DummyResponseWriter
|
||||
for _, srv := range srvs {
|
||||
// wake only if server implements Waker
|
||||
_, ok := srv.handler.(idlewatcher.Waker)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wakeReq := r.Clone(ctx)
|
||||
srv.ServeHTTP(dummyRW, wakeReq)
|
||||
}
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
|
@ -261,10 +279,9 @@ func (lb *LoadBalancer) String() string {
|
|||
func (lb *LoadBalancer) availServers() []*Server {
|
||||
avail := make([]*Server, 0, lb.pool.Size())
|
||||
lb.pool.RangeAll(func(_ string, srv *Server) {
|
||||
if srv.Status().Bad() {
|
||||
return
|
||||
if srv.Status().Good() {
|
||||
avail = append(avail, srv)
|
||||
}
|
||||
avail = append(avail, srv)
|
||||
})
|
||||
return avail
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@ func (lb *roundRobin) OnAddServer(srv *Server) {}
|
|||
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1)
|
||||
srvs[index%uint32(len(srvs))].ServeHTTP(rw, r)
|
||||
index := lb.index.Add(1) % uint32(len(srvs))
|
||||
srvs[index].ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
lb.index.Store(0)
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
|||
}
|
||||
}
|
||||
|
||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
wl := new(cidrWhitelist)
|
||||
wl.m = &Middleware{
|
||||
impl: wl,
|
||||
|
|
|
@ -33,7 +33,7 @@ var CloudflareRealIP = &realIP{
|
|||
m: &Middleware{withOptions: NewCloudflareRealIP},
|
||||
}
|
||||
|
||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
|
||||
cri := new(realIP)
|
||||
cri.m = &Middleware{
|
||||
impl: cri,
|
||||
|
|
|
@ -36,7 +36,7 @@ var ForwardAuth = &forwardAuth{
|
|||
m: &Middleware{withOptions: NewForwardAuthfunc},
|
||||
}
|
||||
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
fa.forwardAuthOpts = new(forwardAuthOpts)
|
||||
err := Deserialize(optsRaw, fa.forwardAuthOpts)
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
Error = E.NestedError
|
||||
Error = E.Error
|
||||
|
||||
ReverseProxy = gphttp.ReverseProxy
|
||||
ProxyRequest = gphttp.ProxyRequest
|
||||
|
@ -24,7 +24,7 @@ type (
|
|||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc func(resp *Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||
|
||||
OptionsRaw = map[string]any
|
||||
Options any
|
||||
|
@ -77,7 +77,7 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||
return m.withOptions(optsRaw)
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
|
|||
}
|
||||
|
||||
// TODO: check conflict or duplicates.
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) {
|
||||
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
|
||||
middlewares = make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
|
@ -136,7 +136,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
|
|||
return
|
||||
}
|
||||
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) {
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
|
||||
var middlewares []*Middleware
|
||||
middlewares, err = createMiddlewares(middlewaresMap)
|
||||
if err != nil {
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
|
||||
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("read middleware compose file", err)
|
||||
|
@ -18,7 +18,7 @@ func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E
|
|||
return BuildMiddlewaresFromYAML(fileContent)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
|
||||
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
|
||||
b := E.NewBuilder("middlewares compile errors")
|
||||
defer b.To(&outErr)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{
|
|||
m: &Middleware{withOptions: NewModifyRequest},
|
||||
}
|
||||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyRequest)
|
||||
var mrFunc RewriteFunc
|
||||
if common.IsDebug {
|
||||
|
|
|
@ -24,7 +24,7 @@ var ModifyResponse = &modifyResponse{
|
|||
m: &Middleware{withOptions: NewModifyResponse},
|
||||
}
|
||||
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{impl: mr}
|
||||
if common.IsDebug {
|
||||
|
|
|
@ -26,7 +26,7 @@ var OAuth2 = &oAuth2{
|
|||
m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
}
|
||||
|
||||
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
oauth := new(oAuth2)
|
||||
oauth.m = &Middleware{
|
||||
impl: oauth,
|
||||
|
|
|
@ -41,7 +41,7 @@ var realIPOptsDefault = func() *realIPOpts {
|
|||
}
|
||||
}
|
||||
|
||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
|
||||
riWithOpts := new(realIP)
|
||||
riWithOpts.m = &Middleware{
|
||||
impl: riWithOpts,
|
||||
|
|
|
@ -72,7 +72,7 @@ type testArgs struct {
|
|||
scheme string
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
var body io.Reader
|
||||
var rr requestRecorder
|
||||
var proxyURL *url.URL
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
type CIDR net.IPNet
|
||||
|
||||
func (cidr *CIDR) ConvertFrom(val any) E.NestedError {
|
||||
func (cidr *CIDR) ConvertFrom(val any) E.Error {
|
||||
cidrStr, ok := val.(string)
|
||||
if !ok {
|
||||
return E.TypeMismatch[string](val)
|
||||
|
|
|
@ -7,13 +7,7 @@ import (
|
|||
|
||||
type Stream interface {
|
||||
fmt.Stringer
|
||||
net.Listener
|
||||
Setup() error
|
||||
Accept() (conn StreamConn, err error)
|
||||
Handle(conn StreamConn) error
|
||||
CloseListeners()
|
||||
}
|
||||
|
||||
type StreamConn interface {
|
||||
RemoteAddr() net.Addr
|
||||
Close() error
|
||||
Handle(conn net.Conn) error
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package entry
|
||||
|
||||
import (
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
|
@ -18,7 +18,7 @@ type Entry interface {
|
|||
IdlewatcherConfig() *idlewatcher.Config
|
||||
}
|
||||
|
||||
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) {
|
||||
func ValidateEntry(m *RawEntry) (Entry, E.Error) {
|
||||
m.FillMissingFields()
|
||||
|
||||
scheme, err := T.NewScheme(m.Scheme)
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
|
||||
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.Error) {
|
||||
h := make(http.Header)
|
||||
for k, v := range headers {
|
||||
vSplit := strings.Split(v, ",")
|
||||
|
|
|
@ -9,6 +9,6 @@ type (
|
|||
Subdomain = Alias
|
||||
)
|
||||
|
||||
func ValidateHost[String ~string](s String) (Host, E.NestedError) {
|
||||
func ValidateHost[String ~string](s String) (Host, E.Error) {
|
||||
return Host(s), nil
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type (
|
|||
|
||||
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
|
||||
|
||||
func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
|
||||
func ValidatePathPattern(s string) (PathPattern, E.Error) {
|
||||
if len(s) == 0 {
|
||||
return "", E.Invalid("path", "must not be empty")
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
|
|||
return PathPattern(s), nil
|
||||
}
|
||||
|
||||
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
|
||||
func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
|
||||
if len(s) == 0 {
|
||||
return []PathPattern{"/"}, nil
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
type Port int
|
||||
|
||||
func ValidatePort[String ~string](v String) (Port, E.NestedError) {
|
||||
func ValidatePort[String ~string](v String) (Port, E.Error) {
|
||||
p, err := strconv.Atoi(string(v))
|
||||
if err != nil {
|
||||
return ErrPort, E.Invalid("port number", v).With(err)
|
||||
|
@ -16,7 +16,7 @@ func ValidatePort[String ~string](v String) (Port, E.NestedError) {
|
|||
return ValidatePortInt(p)
|
||||
}
|
||||
|
||||
func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) {
|
||||
func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) {
|
||||
p := Port(v)
|
||||
if !p.inBound() {
|
||||
return ErrPort, E.OutOfRange("port", p)
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type Scheme string
|
||||
|
||||
func NewScheme[String ~string](s String) (Scheme, E.NestedError) {
|
||||
func NewScheme[String ~string](s String) (Scheme, E.Error) {
|
||||
switch s {
|
||||
case "http", "https", "tcp", "udp":
|
||||
return Scheme(s), nil
|
||||
|
|
|
@ -12,7 +12,7 @@ type StreamPort struct {
|
|||
ProxyPort Port `json:"proxy"`
|
||||
}
|
||||
|
||||
func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) {
|
||||
func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
|
||||
split := strings.Split(p, ":")
|
||||
|
||||
switch len(split) {
|
||||
|
@ -47,7 +47,7 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) {
|
|||
return StreamPort{listeningPort, proxyPort}, nil
|
||||
}
|
||||
|
||||
func parseNameToPort(name string) (Port, E.NestedError) {
|
||||
func parseNameToPort(name string) (Port, E.Error) {
|
||||
port, ok := common.ServiceNamePortMapTCP[name]
|
||||
if !ok {
|
||||
return ErrPort, E.Invalid("service", name)
|
||||
|
|
|
@ -12,7 +12,7 @@ type StreamScheme struct {
|
|||
ProxyScheme Scheme `json:"proxy"`
|
||||
}
|
||||
|
||||
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
|
||||
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) {
|
||||
ss = &StreamScheme{}
|
||||
parts := strings.Split(s, ":")
|
||||
if len(parts) == 1 {
|
||||
|
|
|
@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
|
||||
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
|
||||
var trans *http.Transport
|
||||
|
||||
if entry.NoTLSVerify {
|
||||
|
@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string {
|
|||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
|
@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
|||
r.addToLoadBalancer()
|
||||
} else {
|
||||
httpRoutes.Store(string(r.Alias), r)
|
||||
r.task.OnComplete("stop rp", func() {
|
||||
r.task.OnFinished("remove from route table", func() {
|
||||
httpRoutes.Delete(string(r.Alias))
|
||||
})
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
|
|||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (r *HTTPRoute) Finish(reason string) {
|
||||
func (r *HTTPRoute) Finish(reason any) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
|
@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
|||
}
|
||||
} else {
|
||||
lb = loadbalancer.New(r.LoadBalance)
|
||||
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
|
||||
lbTask.OnComplete("remove lb from routes", func() {
|
||||
lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link)
|
||||
lbTask.OnCancel("remove lb from routes", func() {
|
||||
httpRoutes.Delete(r.LoadBalance.Link)
|
||||
})
|
||||
lb.Start(lbTask)
|
||||
|
@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
|
|||
httpRoutes.Store(r.LoadBalance.Link, linked)
|
||||
}
|
||||
r.loadBalancer = lb
|
||||
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
|
||||
r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
|
||||
lb.AddServer(r.server)
|
||||
r.task.OnComplete("remove server from lb", func() {
|
||||
r.task.OnCancel("remove server from lb", func() {
|
||||
lb.RemoveServer(r.server)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ var (
|
|||
AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
|
||||
)
|
||||
|
||||
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) {
|
||||
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) {
|
||||
if dockerHost == common.DockerHostFromEnv {
|
||||
dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost)
|
||||
}
|
||||
|
@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
|
|||
return W.NewDockerWatcher(p.dockerHost)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
||||
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) {
|
||||
routes = R.NewRoutes()
|
||||
entries := entry.NewProxyEntries()
|
||||
|
||||
info, err := D.GetClientInfo(p.dockerHost, true)
|
||||
containers, err := D.ListContainers(p.dockerHost)
|
||||
if err != nil {
|
||||
return routes, E.FailWith("connect to docker", err)
|
||||
return routes, err
|
||||
}
|
||||
|
||||
errors := E.NewBuilder("errors in docker labels")
|
||||
|
||||
for _, c := range info.Containers {
|
||||
for _, c := range containers {
|
||||
container := D.FromDocker(&c, p.dockerHost)
|
||||
if container.IsExcluded {
|
||||
continue
|
||||
|
@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
|
|||
})
|
||||
}
|
||||
|
||||
entries.RangeAll(func(_ string, e *entry.RawEntry) {
|
||||
e.Container.DockerHost = p.dockerHost
|
||||
})
|
||||
|
||||
routes, err = R.FromEntries(entries)
|
||||
errors.Add(err)
|
||||
|
||||
|
@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
|
|||
|
||||
// Returns a list of proxy entries for a container.
|
||||
// Always non-nil.
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
|
||||
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) {
|
||||
entries = entry.NewProxyEntries()
|
||||
|
||||
if p.shouldIgnore(container) {
|
||||
|
@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
|
|||
return entries, errors.Build().Subject(container.ContainerName)
|
||||
}
|
||||
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
|
||||
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) {
|
||||
b := E.NewBuilder("errors in label %s", key)
|
||||
defer b.To(&res)
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
|
@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
|
|||
return
|
||||
}
|
||||
|
||||
oldRoutes.RangeAll(func(k string, v *route.Route) {
|
||||
if !newRoutes.Has(k) {
|
||||
handler.Remove(v)
|
||||
if common.IsDebug {
|
||||
eventsLog := E.NewBuilder("events")
|
||||
for _, event := range events {
|
||||
eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID)
|
||||
}
|
||||
handler.provider.l.Debug(eventsLog.String())
|
||||
oldRoutesLog := E.NewBuilder("old routes")
|
||||
oldRoutes.RangeAll(func(k string, r *route.Route) {
|
||||
oldRoutesLog.Addf(k)
|
||||
})
|
||||
handler.provider.l.Debug(oldRoutesLog.String())
|
||||
newRoutesLog := E.NewBuilder("new routes")
|
||||
newRoutes.RangeAll(func(k string, r *route.Route) {
|
||||
newRoutesLog.Addf(k)
|
||||
})
|
||||
handler.provider.l.Debug(newRoutesLog.String())
|
||||
}
|
||||
|
||||
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
|
||||
newr, ok := newRoutes.Load(k)
|
||||
if !ok {
|
||||
handler.Remove(oldr)
|
||||
} else if handler.matchAny(events, newr) {
|
||||
handler.Update(parent, oldr, newr)
|
||||
} else if entry.ShouldNotServe(newr) {
|
||||
handler.Remove(oldr)
|
||||
}
|
||||
})
|
||||
newRoutes.RangeAll(func(k string, newr *route.Route) {
|
||||
if oldRoutes.Has(k) {
|
||||
for _, ev := range events {
|
||||
if handler.match(ev, newr) {
|
||||
old, ok := oldRoutes.Load(k)
|
||||
if !ok { // should not happen
|
||||
panic("race condition")
|
||||
}
|
||||
handler.Update(parent, old, newr)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) {
|
||||
handler.Add(parent, newr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool {
|
||||
for _, event := range events {
|
||||
if handler.match(event, route) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
|
||||
switch handler.provider.t {
|
||||
switch handler.provider.GetType() {
|
||||
case ProviderTypeDocker:
|
||||
return route.Entry.Container.ContainerID == event.ActorID ||
|
||||
route.Entry.Container.ContainerName == event.ActorName
|
||||
|
@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
|
|||
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
|
||||
err := handler.provider.startRoute(parent, route)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err))
|
||||
} else {
|
||||
handler.added = append(handler.added, route.Entry.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *EventHandler) Remove(route *route.Route) {
|
||||
route.Finish("route removal")
|
||||
route.Finish("route removed")
|
||||
handler.provider.routes.Delete(route.Entry.Alias)
|
||||
handler.removed = append(handler.removed, route.Entry.Alias)
|
||||
}
|
||||
|
||||
|
@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
|
|||
oldRoute.Finish("route update")
|
||||
err := handler.provider.startRoute(parent, newRoute)
|
||||
if err != nil {
|
||||
handler.errs.Add(err)
|
||||
handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err))
|
||||
} else {
|
||||
handler.updated = append(handler.updated, newRoute.Entry.Alias)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type FileProvider struct {
|
|||
path string
|
||||
}
|
||||
|
||||
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
|
||||
func FileProviderImpl(filename string) (ProviderImpl, E.Error) {
|
||||
impl := &FileProvider{
|
||||
fileName: filename,
|
||||
path: path.Join(common.ConfigBasePath, filename),
|
||||
|
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
|
|||
}
|
||||
}
|
||||
|
||||
func Validate(data []byte) E.NestedError {
|
||||
func Validate(data []byte) E.Error {
|
||||
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,7 @@ func (p FileProvider) String() string {
|
|||
return p.fileName
|
||||
}
|
||||
|
||||
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
|
||||
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) {
|
||||
routes = R.NewRoutes()
|
||||
|
||||
b := E.NewBuilder("validation failure")
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
|
||||
"github.com/sirupsen/logrus"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/proxy/entry"
|
||||
R "github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
|
@ -29,7 +28,7 @@ type (
|
|||
ProviderImpl interface {
|
||||
fmt.Stringer
|
||||
NewWatcher() W.Watcher
|
||||
LoadRoutesImpl() (R.Routes, E.NestedError)
|
||||
LoadRoutesImpl() (R.Routes, E.Error)
|
||||
}
|
||||
ProviderType string
|
||||
ProviderStats struct {
|
||||
|
@ -43,7 +42,7 @@ const (
|
|||
ProviderTypeDocker ProviderType = "docker"
|
||||
ProviderTypeFile ProviderType = "file"
|
||||
|
||||
providerEventFlushInterval = 500 * time.Millisecond
|
||||
providerEventFlushInterval = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
func newProvider(name string, t ProviderType) *Provider {
|
||||
|
@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider {
|
|||
return p
|
||||
}
|
||||
|
||||
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
|
||||
func NewFileProvider(filename string) (p *Provider, err E.Error) {
|
||||
name := path.Base(filename)
|
||||
if name == "" {
|
||||
return nil, E.Invalid("file name", "empty")
|
||||
|
@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
|
|||
return
|
||||
}
|
||||
|
||||
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
|
||||
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) {
|
||||
if name == "" {
|
||||
return nil, E.Invalid("provider name", "empty")
|
||||
}
|
||||
|
@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) {
|
|||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
|
||||
if entry.UseLoadBalance(r) {
|
||||
r.Entry.Alias = p.String() + "/" + r.Entry.Alias
|
||||
}
|
||||
subtask := parent.Subtask(r.Entry.Alias)
|
||||
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
|
||||
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
|
||||
err := r.Start(subtask)
|
||||
if err != nil {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
subtask.Finish(err.String()) // just to ensure
|
||||
subtask.Finish(err) // just to ensure
|
||||
return err
|
||||
} else {
|
||||
subtask.OnComplete("del from provider", func() {
|
||||
p.routes.Store(r.Entry.Alias, r)
|
||||
subtask.OnFinished("del from provider", func() {
|
||||
p.routes.Delete(r.Entry.Alias)
|
||||
})
|
||||
}
|
||||
|
@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
|
|||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
|
||||
func (p *Provider) Start(configSubtask task.Task) (res E.Error) {
|
||||
errors := E.NewBuilder("errors starting routes")
|
||||
defer errors.To(&res)
|
||||
|
||||
|
@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
|
|||
handler.Log()
|
||||
flushTask.Finish("events flushed")
|
||||
},
|
||||
func(err E.NestedError) {
|
||||
func(err E.Error) {
|
||||
p.l.Error(err)
|
||||
},
|
||||
)
|
||||
|
@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
|
|||
return p.routes.Load(alias)
|
||||
}
|
||||
|
||||
func (p *Provider) LoadRoutes() E.NestedError {
|
||||
var err E.NestedError
|
||||
func (p *Provider) LoadRoutes() E.Error {
|
||||
var err E.Error
|
||||
p.routes, err = p.LoadRoutesImpl()
|
||||
if p.routes.Size() > 0 {
|
||||
return err
|
||||
|
|
94
internal/route/raw.go
Normal file
94
internal/route/raw.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
RawStream struct {
|
||||
*StreamRoute
|
||||
|
||||
listener net.Listener
|
||||
targetAddr net.Addr
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
streamBufferSize = 8192
|
||||
streamDialTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func NewRawStreamRoute(base *StreamRoute) *RawStream {
|
||||
return &RawStream{
|
||||
StreamRoute: base,
|
||||
}
|
||||
}
|
||||
|
||||
func (route *RawStream) Setup() error {
|
||||
var lcfg net.ListenConfig
|
||||
var err error
|
||||
|
||||
switch route.Scheme.ListeningScheme {
|
||||
case "tcp":
|
||||
route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
|
||||
route.listener = tcpListener
|
||||
case "udp":
|
||||
route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port)
|
||||
route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener)
|
||||
default:
|
||||
return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (route *RawStream) Accept() (net.Conn, error) {
|
||||
if route.listener == nil {
|
||||
return nil, errors.New("listener not yet set up")
|
||||
}
|
||||
return route.listener.Accept()
|
||||
}
|
||||
|
||||
func (route *RawStream) Handle(c net.Conn) error {
|
||||
clientConn := c.(net.Conn)
|
||||
|
||||
defer clientConn.Close()
|
||||
route.task.OnCancel("close conn", func() { clientConn.Close() })
|
||||
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
|
||||
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
return pipe.Start()
|
||||
}
|
||||
|
||||
func (route *RawStream) Close() error {
|
||||
return route.listener.Close()
|
||||
}
|
|
@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container {
|
|||
return rt.Entry.Container
|
||||
}
|
||||
|
||||
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
|
||||
func NewRoute(raw *entry.RawEntry) (*Route, E.Error) {
|
||||
en, err := entry.ValidateEntry(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
|
||||
func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
|
||||
b := E.NewBuilder("errors in routes")
|
||||
|
||||
routes := NewRoutes()
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
stdNet "net"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] {
|
|||
return streamRoutes
|
||||
}
|
||||
|
||||
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
|
||||
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
|
||||
// TODO: support non-coherent scheme
|
||||
if !entry.Scheme.IsCoherent() {
|
||||
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
|
||||
|
@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Finish(reason string) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) String() string {
|
||||
return fmt.Sprintf("stream %s", r.Alias)
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
||||
func (r *StreamRoute) Start(providerSubtask task.Task) E.Error {
|
||||
if entry.ShouldNotServe(r) {
|
||||
providerSubtask.Finish("should not serve")
|
||||
return nil
|
||||
|
@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
|||
r.HealthCheck.Disable = true
|
||||
}
|
||||
|
||||
if r.Scheme.ListeningScheme.IsTCP() {
|
||||
r.Stream = NewTCPRoute(r)
|
||||
} else {
|
||||
r.Stream = NewUDPRoute(r)
|
||||
}
|
||||
// if r.Scheme.ListeningScheme.IsTCP() {
|
||||
// r.Stream = NewTCPRoute(r)
|
||||
// } else {
|
||||
// r.Stream = NewUDPRoute(r)
|
||||
// }
|
||||
r.task = providerSubtask
|
||||
r.Stream = NewRawStreamRoute(r)
|
||||
r.l = logrus.WithField("route", r.Stream.String())
|
||||
|
||||
switch {
|
||||
|
@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
|||
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
|
||||
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
|
||||
if err != nil {
|
||||
r.task.Finish(err)
|
||||
return err
|
||||
}
|
||||
r.Stream = waker
|
||||
|
@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
|
|||
case entry.UseHealthCheck(r):
|
||||
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
|
||||
}
|
||||
r.task = providerSubtask
|
||||
r.task.OnComplete("stop stream", r.CloseListeners)
|
||||
|
||||
if err := r.Setup(); err != nil {
|
||||
r.task.Finish(err)
|
||||
return E.FailWith("setup", err)
|
||||
}
|
||||
r.l.Infof("listening on port %d", r.Port.ListeningPort)
|
||||
|
||||
go r.acceptConnections()
|
||||
r.task.OnFinished("close stream", func() {
|
||||
if err := r.Close(); err != nil {
|
||||
r.l.Error("close stream error: ", err)
|
||||
}
|
||||
})
|
||||
r.task.OnFinished("remove from route table", func() {
|
||||
streamRoutes.Delete(string(r.Alias))
|
||||
})
|
||||
|
||||
r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort)
|
||||
|
||||
if r.HealthMon != nil {
|
||||
r.HealthMon.Start(r.task.Subtask("health monitor"))
|
||||
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
|
||||
logrus.Warn("health monitor error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
go r.acceptConnections()
|
||||
streamRoutes.Store(string(r.Alias), r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Finish(reason any) {
|
||||
r.task.Finish(reason)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) acceptConnections() {
|
||||
defer r.task.Finish("listener closed")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
|
@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() {
|
|||
if err != nil {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
return
|
||||
default:
|
||||
var nErr *stdNet.OpError
|
||||
ok := errors.As(err, &nErr)
|
||||
if !(ok && nErr.Timeout()) {
|
||||
r.l.Error("accept connection error: ", err)
|
||||
r.task.Finish(err.Error())
|
||||
return
|
||||
}
|
||||
continue
|
||||
r.l.Error("accept connection error: ", err)
|
||||
r.task.Finish(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr()))
|
||||
go func() {
|
||||
err := r.Handle(conn)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
r.l.Error(err)
|
||||
connTask.Finish(err.Error())
|
||||
} else {
|
||||
connTask.Finish("connection closed")
|
||||
}
|
||||
|
|
|
@ -1,71 +1,68 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "net"
|
||||
// "time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
// "github.com/yusing/go-proxy/internal/net/types"
|
||||
// T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
// U "github.com/yusing/go-proxy/internal/utils"
|
||||
// F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
// )
|
||||
|
||||
const tcpDialTimeout = 5 * time.Second
|
||||
// const tcpDialTimeout = 5 * time.Second
|
||||
|
||||
type (
|
||||
TCPConnMap = F.Map[net.Conn, struct{}]
|
||||
TCPRoute struct {
|
||||
*StreamRoute
|
||||
listener *net.TCPListener
|
||||
}
|
||||
)
|
||||
// type (
|
||||
// TCPConnMap = F.Map[net.Conn, struct{}]
|
||||
// TCPRoute struct {
|
||||
// *StreamRoute
|
||||
// listener *net.TCPListener
|
||||
// }
|
||||
// )
|
||||
|
||||
func NewTCPRoute(base *StreamRoute) *TCPRoute {
|
||||
return &TCPRoute{StreamRoute: base}
|
||||
}
|
||||
// func NewTCPRoute(base *StreamRoute) *TCPRoute {
|
||||
// return &TCPRoute{StreamRoute: base}
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Setup() error {
|
||||
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//! this read the allocated port from original ':0'
|
||||
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
|
||||
route.listener = in.(*net.TCPListener)
|
||||
return nil
|
||||
}
|
||||
// func (route *TCPRoute) Setup() error {
|
||||
// var cfg net.ListenConfig
|
||||
// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// //! this read the allocated port from original ':0'
|
||||
// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
|
||||
// route.listener = in.(*net.TCPListener)
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Accept() (types.StreamConn, error) {
|
||||
route.listener.SetDeadline(time.Now().Add(time.Second))
|
||||
return route.listener.Accept()
|
||||
}
|
||||
// func (route *TCPRoute) Accept() (types.StreamConn, error) {
|
||||
// return route.listener.Accept()
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) Handle(c types.StreamConn) error {
|
||||
clientConn := c.(net.Conn)
|
||||
// func (route *TCPRoute) Handle(c types.StreamConn) error {
|
||||
// clientConn := c.(net.Conn)
|
||||
|
||||
defer clientConn.Close()
|
||||
route.task.OnComplete("close conn", func() { clientConn.Close() })
|
||||
// defer clientConn.Close()
|
||||
// route.task.OnCancel("close conn", func() { clientConn.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
|
||||
// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
|
||||
|
||||
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
dialer := &net.Dialer{}
|
||||
// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
|
||||
// dialer := &net.Dialer{}
|
||||
|
||||
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
|
||||
// cancel()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
return pipe.Start()
|
||||
}
|
||||
// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
|
||||
// return pipe.Start()
|
||||
// }
|
||||
|
||||
func (route *TCPRoute) CloseListeners() {
|
||||
if route.listener == nil {
|
||||
return
|
||||
}
|
||||
route.listener.Close()
|
||||
}
|
||||
// func (route *TCPRoute) Close() error {
|
||||
// return route.listener.Close()
|
||||
// }
|
||||
|
|
|
@ -1,145 +1,149 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
// import (
|
||||
// "errors"
|
||||
// "fmt"
|
||||
// "io"
|
||||
// "net"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
// "github.com/yusing/go-proxy/internal/net/types"
|
||||
// T "github.com/yusing/go-proxy/internal/proxy/fields"
|
||||
// U "github.com/yusing/go-proxy/internal/utils"
|
||||
// F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
// )
|
||||
|
||||
type (
|
||||
UDPRoute struct {
|
||||
*StreamRoute
|
||||
// type (
|
||||
// UDPRoute struct {
|
||||
// *StreamRoute
|
||||
|
||||
connMap UDPConnMap
|
||||
// connMap UDPConnMap
|
||||
|
||||
listeningConn *net.UDPConn
|
||||
targetAddr *net.UDPAddr
|
||||
}
|
||||
UDPConn struct {
|
||||
key string
|
||||
src *net.UDPConn
|
||||
dst *net.UDPConn
|
||||
U.BidirectionalPipe
|
||||
}
|
||||
UDPConnMap = F.Map[string, *UDPConn]
|
||||
)
|
||||
// listeningConn net.PacketConn
|
||||
// targetAddr *net.UDPAddr
|
||||
// }
|
||||
// UDPConn struct {
|
||||
// key string
|
||||
// src net.Conn
|
||||
// dst net.Conn
|
||||
// U.BidirectionalPipe
|
||||
// }
|
||||
// UDPConnMap = F.Map[string, *UDPConn]
|
||||
// )
|
||||
|
||||
var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
// var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
|
||||
const udpBufferSize = 8192
|
||||
// const udpBufferSize = 8192
|
||||
|
||||
func NewUDPRoute(base *StreamRoute) *UDPRoute {
|
||||
return &UDPRoute{
|
||||
StreamRoute: base,
|
||||
connMap: NewUDPConnMap(),
|
||||
}
|
||||
}
|
||||
// func NewUDPRoute(base *StreamRoute) *UDPRoute {
|
||||
// return &UDPRoute{
|
||||
// StreamRoute: base,
|
||||
// connMap: NewUDPConnMap(),
|
||||
// }
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Setup() error {
|
||||
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
if err != nil {
|
||||
source.Close()
|
||||
return err
|
||||
}
|
||||
// func (route *UDPRoute) Setup() error {
|
||||
// var cfg net.ListenConfig
|
||||
// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
|
||||
// if err != nil {
|
||||
// source.Close()
|
||||
// return err
|
||||
// }
|
||||
|
||||
//! this read the allocated listeningPort from original ':0'
|
||||
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
|
||||
// //! this read the allocated listeningPort from original ':0'
|
||||
// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
route.listeningConn = source
|
||||
route.targetAddr = raddr
|
||||
// route.listeningConn = source
|
||||
// route.targetAddr = raddr
|
||||
|
||||
return nil
|
||||
}
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Accept() (types.StreamConn, error) {
|
||||
in := route.listeningConn
|
||||
// func (route *UDPRoute) Accept() (types.StreamConn, error) {
|
||||
// in := route.listeningConn
|
||||
|
||||
buffer := make([]byte, udpBufferSize)
|
||||
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
|
||||
nRead, srcAddr, err := in.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// buffer := make([]byte, udpBufferSize)
|
||||
// nRead, srcAddr, err := in.ReadFrom(buffer)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
if nRead == 0 {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
// if nRead == 0 {
|
||||
// return nil, io.ErrShortBuffer
|
||||
// }
|
||||
|
||||
key := srcAddr.String()
|
||||
conn, ok := route.connMap.Load(key)
|
||||
// key := srcAddr.String()
|
||||
// conn, ok := route.connMap.Load(key)
|
||||
|
||||
if !ok {
|
||||
srcConn, err := net.DialUDP("udp", nil, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
|
||||
if err != nil {
|
||||
srcConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn = &UDPConn{
|
||||
key,
|
||||
srcConn,
|
||||
dstConn,
|
||||
U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
|
||||
}
|
||||
route.connMap.Store(key, conn)
|
||||
}
|
||||
// if !ok {
|
||||
// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String())
|
||||
// if err != nil {
|
||||
// srcConn.Close()
|
||||
// return nil, err
|
||||
// }
|
||||
// conn = &UDPConn{
|
||||
// key,
|
||||
// srcConn,
|
||||
// dstConn,
|
||||
// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
|
||||
// }
|
||||
// route.connMap.Store(key, conn)
|
||||
// }
|
||||
|
||||
_, err = conn.dst.Write(buffer[:nRead])
|
||||
return conn, err
|
||||
}
|
||||
// _, err = conn.dst.Write(buffer[:nRead])
|
||||
// return conn, err
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) Handle(c types.StreamConn) error {
|
||||
conn := c.(*UDPConn)
|
||||
err := conn.Start()
|
||||
route.connMap.Delete(conn.key)
|
||||
return err
|
||||
}
|
||||
// func (route *UDPRoute) Handle(c types.StreamConn) error {
|
||||
// switch c := c.(type) {
|
||||
// case *UDPConn:
|
||||
// err := c.Start()
|
||||
// route.connMap.Delete(c.key)
|
||||
// c.Close()
|
||||
// return err
|
||||
// case *net.TCPConn:
|
||||
// in := route.listeningConn
|
||||
// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start()
|
||||
// c.Close()
|
||||
// return err
|
||||
// }
|
||||
// return fmt.Errorf("unknown conn type: %T", c)
|
||||
// }
|
||||
|
||||
func (route *UDPRoute) CloseListeners() {
|
||||
if route.listeningConn != nil {
|
||||
route.listeningConn.Close()
|
||||
}
|
||||
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
|
||||
if err := conn.Close(); err != nil {
|
||||
route.l.Errorf("error closing conn: %s", err)
|
||||
}
|
||||
})
|
||||
route.connMap.Clear()
|
||||
}
|
||||
// func (route *UDPRoute) Close() error {
|
||||
// route.connMap.RangeAllParallel(func(k string, v *UDPConn) {
|
||||
// v.Close()
|
||||
// })
|
||||
// route.connMap.Clear()
|
||||
// return route.listeningConn.Close()
|
||||
// }
|
||||
|
||||
// Close implements types.StreamConn
|
||||
func (conn *UDPConn) Close() error {
|
||||
return errors.Join(conn.src.Close(), conn.dst.Close())
|
||||
}
|
||||
// // Close implements types.StreamConn
|
||||
// func (conn *UDPConn) Close() error {
|
||||
// return errors.Join(conn.src.Close(), conn.dst.Close())
|
||||
// }
|
||||
|
||||
// RemoteAddr implements types.StreamConn
|
||||
func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
return conn.src.RemoteAddr()
|
||||
}
|
||||
// // RemoteAddr implements types.StreamConn
|
||||
// func (conn *UDPConn) RemoteAddr() net.Addr {
|
||||
// return conn.src.RemoteAddr()
|
||||
// }
|
||||
|
||||
type sourceRWCloser struct {
|
||||
server *net.UDPConn
|
||||
*net.UDPConn
|
||||
}
|
||||
// type sourceRWCloser struct {
|
||||
// server net.PacketConn
|
||||
// net.Conn
|
||||
// }
|
||||
|
||||
func (w sourceRWCloser) Write(p []byte) (int, error) {
|
||||
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
|
||||
}
|
||||
// func (w sourceRWCloser) Write(p []byte) (int, error) {
|
||||
// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr))
|
||||
// }
|
||||
|
|
73
internal/route/udp_listener.go
Normal file
73
internal/route/udp_listener.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
UDPListener struct {
|
||||
ctx context.Context
|
||||
listener net.PacketConn
|
||||
connMap UDPConnMap
|
||||
mu sync.Mutex
|
||||
}
|
||||
UDPConnMap = F.Map[string, net.Conn]
|
||||
)
|
||||
|
||||
var NewUDPConnMap = F.NewMap[UDPConnMap]
|
||||
|
||||
func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener {
|
||||
return &UDPListener{
|
||||
ctx: ctx,
|
||||
listener: listener,
|
||||
connMap: NewUDPConnMap(),
|
||||
}
|
||||
}
|
||||
|
||||
// Addr implements net.Listener.
|
||||
func (route *UDPListener) Addr() net.Addr {
|
||||
return route.listener.LocalAddr()
|
||||
}
|
||||
|
||||
func (udpl *UDPListener) Accept() (net.Conn, error) {
|
||||
in := udpl.listener
|
||||
|
||||
buffer := make([]byte, streamBufferSize)
|
||||
nRead, srcAddr, err := in.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if nRead == 0 {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
udpl.mu.Lock()
|
||||
defer udpl.mu.Unlock()
|
||||
|
||||
key := srcAddr.String()
|
||||
conn, ok := udpl.connMap.Load(key)
|
||||
if !ok {
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpl.connMap.Store(key, srcConn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close implements net.Listener.
|
||||
func (route *UDPListener) Close() error {
|
||||
route.connMap.RangeAllParallel(func(key string, conn net.Conn) {
|
||||
conn.Close()
|
||||
})
|
||||
route.connMap.Clear()
|
||||
return route.listener.Close()
|
||||
}
|
|
@ -116,7 +116,7 @@ func (s *Server) Start() {
|
|||
}()
|
||||
}
|
||||
|
||||
s.task.OnComplete("stop server", s.stop)
|
||||
s.task.OnFinished("stop server", s.stop)
|
||||
}
|
||||
|
||||
func (s *Server) stop() {
|
||||
|
|
|
@ -39,10 +39,11 @@ type (
|
|||
// Use Task.Finish to stop all subtasks of the task.
|
||||
Task interface {
|
||||
TaskFinisher
|
||||
fmt.Stringer
|
||||
// Name returns the name of the task.
|
||||
Name() string
|
||||
// Context returns the context associated with the task. This context is
|
||||
// canceled when Finish is called.
|
||||
// canceled when Finish of the task is called, or parent task is canceled.
|
||||
Context() context.Context
|
||||
// FinishCause returns the reason / error that caused the task to be finished.
|
||||
FinishCause() error
|
||||
|
@ -53,12 +54,16 @@ type (
|
|||
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
|
||||
//
|
||||
// This should not be called after Finish, Wait, or WaitSubTasks is called.
|
||||
Subtask(usageFmt string, args ...any) Task
|
||||
// OnComplete calls fn when the task and all subtasks are finished.
|
||||
Subtask(name string) Task
|
||||
// OnFinished calls fn when all subtasks are finished.
|
||||
//
|
||||
// It cannot be called after Finish or Wait is called.
|
||||
OnComplete(about string, fn func())
|
||||
// Wait waits for all subtasks, itself and all OnComplete to finish.
|
||||
OnFinished(about string, fn func())
|
||||
// OnCancel calls fn when the task is canceled.
|
||||
//
|
||||
// It cannot be called after Finish or Wait is called.
|
||||
OnCancel(about string, fn func())
|
||||
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
|
||||
//
|
||||
// It must be called only after Finish is called.
|
||||
Wait()
|
||||
|
@ -76,37 +81,46 @@ type (
|
|||
// The task passed must be a subtask of the caller task.
|
||||
//
|
||||
// callerSubtask.Finish must be called when start fails or the object is finished.
|
||||
Start(callerSubtask Task) E.NestedError
|
||||
Start(callerSubtask Task) E.Error
|
||||
}
|
||||
TaskFinisher interface {
|
||||
// Finish marks the task as finished by cancelling its context.
|
||||
// Finish marks the task as finished and cancel its context.
|
||||
//
|
||||
// Then call Wait to wait for all subtasks and OnComplete of the task to finish.
|
||||
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
|
||||
// of the task to finish.
|
||||
//
|
||||
// Note that it will also cancel all subtasks.
|
||||
Finish(reason string)
|
||||
Finish(reason any)
|
||||
}
|
||||
task struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
|
||||
parent *task
|
||||
subtasks *xsync.MapOf[*task, struct{}]
|
||||
parent *task
|
||||
subtasks *xsync.MapOf[*task, struct{}]
|
||||
subTasksWg sync.WaitGroup
|
||||
|
||||
name, line string
|
||||
|
||||
subTasksWg, onCompleteWg sync.WaitGroup
|
||||
OnFinishedFuncs []func()
|
||||
OnFinishedMu sync.Mutex
|
||||
onFinishedWg sync.WaitGroup
|
||||
|
||||
finishOnce sync.Once
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProgramExiting = errors.New("program exiting")
|
||||
ErrTaskCancelled = errors.New("task cancelled")
|
||||
ErrTaskCanceled = errors.New("task canceled")
|
||||
)
|
||||
|
||||
// GlobalTask returns a new Task with the given name, derived from the global context.
|
||||
func GlobalTask(format string, args ...any) Task {
|
||||
return globalTask.Subtask(format, args...)
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
return globalTask.Subtask(format)
|
||||
}
|
||||
|
||||
// DebugTaskMap returns a map[string]any representation of the global task tree.
|
||||
|
@ -155,6 +169,10 @@ func (t *task) Name() string {
|
|||
return t.name
|
||||
}
|
||||
|
||||
func (t *task) String() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *task) Context() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
@ -171,43 +189,83 @@ func (t *task) Parent() Task {
|
|||
return t.parent
|
||||
}
|
||||
|
||||
func (t *task) OnComplete(about string, fn func()) {
|
||||
t.onCompleteWg.Add(1)
|
||||
func (t *task) runAllOnFinished(onCompTask Task) {
|
||||
<-t.ctx.Done()
|
||||
t.WaitSubTasks()
|
||||
for _, OnFinishedFunc := range t.OnFinishedFuncs {
|
||||
OnFinishedFunc()
|
||||
t.onFinishedWg.Done()
|
||||
}
|
||||
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
|
||||
}
|
||||
|
||||
func (t *task) OnFinished(about string, fn func()) {
|
||||
if t.parent == globalTask {
|
||||
t.OnCancel(about, fn)
|
||||
return
|
||||
}
|
||||
t.onFinishedWg.Add(1)
|
||||
t.OnFinishedMu.Lock()
|
||||
defer t.OnFinishedMu.Unlock()
|
||||
|
||||
if t.OnFinishedFuncs == nil {
|
||||
onCompTask := GlobalTask(t.name + " > OnFinished")
|
||||
go t.runAllOnFinished(onCompTask)
|
||||
}
|
||||
var file string
|
||||
var line int
|
||||
if common.IsTrace {
|
||||
_, file, line, _ = runtime.Caller(1)
|
||||
}
|
||||
go func() {
|
||||
idx := len(t.OnFinishedFuncs)
|
||||
wrapped := func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err)
|
||||
logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err)
|
||||
}
|
||||
}()
|
||||
defer t.onCompleteWg.Done()
|
||||
t.subTasksWg.Wait()
|
||||
fn()
|
||||
logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about)
|
||||
}
|
||||
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
|
||||
}
|
||||
|
||||
func (t *task) OnCancel(about string, fn func()) {
|
||||
onCompTask := GlobalTask(t.name + " > OnFinished")
|
||||
go func() {
|
||||
<-t.ctx.Done()
|
||||
fn()
|
||||
logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about)
|
||||
t.cancel(nil) // ensure resources are released
|
||||
onCompTask.Finish("done")
|
||||
logrus.Tracef("%s > onCancel done: %s", t.name, about)
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *task) Finish(reason string) {
|
||||
t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason))
|
||||
t.Wait()
|
||||
func (t *task) Finish(reason any) {
|
||||
var format string
|
||||
switch reason.(type) {
|
||||
case error:
|
||||
format = "%w"
|
||||
case string, fmt.Stringer:
|
||||
format = "%s"
|
||||
default:
|
||||
format = "%v"
|
||||
}
|
||||
t.finishOnce.Do(func() {
|
||||
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
|
||||
t.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *task) Subtask(format string, args ...any) Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
func (t *task) Subtask(name string) Task {
|
||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||
return t.newSubTask(ctx, cancel, format)
|
||||
return t.newSubTask(ctx, cancel, name)
|
||||
}
|
||||
|
||||
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
|
||||
parent := t
|
||||
if common.IsTrace {
|
||||
name = parent.name + " > " + name
|
||||
}
|
||||
subtask := &task{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
|
@ -222,10 +280,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
|
|||
if ok {
|
||||
subtask.line = fmt.Sprintf("%s:%d", file, line)
|
||||
}
|
||||
logrus.Tracef("line %s\ntask %q started", subtask.line, name)
|
||||
logrus.Tracef("line %s\n%s started", subtask.line, name)
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
logrus.Tracef("task %q finished", subtask.Name())
|
||||
logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause())
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
|
@ -237,11 +295,9 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
|
|||
}
|
||||
|
||||
func (t *task) Wait() {
|
||||
t.subTasksWg.Wait()
|
||||
if t != globalTask {
|
||||
<-t.ctx.Done()
|
||||
}
|
||||
t.onCompleteWg.Wait()
|
||||
<-t.ctx.Done()
|
||||
t.WaitSubTasks()
|
||||
t.onFinishedWg.Wait()
|
||||
}
|
||||
|
||||
func (t *task) WaitSubTasks() {
|
||||
|
@ -270,9 +326,9 @@ func (t *task) tree(prefix ...string) string {
|
|||
}
|
||||
if t.line != "" {
|
||||
sb.WriteString("line " + t.line + "\n")
|
||||
}
|
||||
if len(pre) > 0 {
|
||||
sb.WriteString(pre + "- ")
|
||||
if len(pre) > 0 {
|
||||
sb.WriteString(pre + "- ")
|
||||
}
|
||||
}
|
||||
sb.WriteString(t.Name() + "\n")
|
||||
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
|
||||
|
@ -299,7 +355,8 @@ func (t *task) tree(prefix ...string) string {
|
|||
// only.
|
||||
func (t *task) serialize() map[string]any {
|
||||
m := make(map[string]any)
|
||||
m["name"] = t.name
|
||||
parts := strings.Split(t.name, ">")
|
||||
m["name"] = strings.TrimSpace(parts[len(parts)-1])
|
||||
if t.line != "" {
|
||||
m["line"] = t.line
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func TestTaskCancellation(t *testing.T) {
|
|||
err := subTask.Context().Err()
|
||||
ExpectError(t, context.Canceled, err)
|
||||
cause := context.Cause(subTask.Context())
|
||||
ExpectError(t, ErrTaskCancelled, cause)
|
||||
ExpectError(t, ErrTaskCanceled, cause)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("subTask context was not canceled as expected")
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ func TestOnComplete(t *testing.T) {
|
|||
|
||||
task := GlobalTask("test")
|
||||
var value atomic.Int32
|
||||
task.OnComplete("set value", func() {
|
||||
task.OnFinished("set value", func() {
|
||||
value.Store(1234)
|
||||
})
|
||||
task.Finish("done")
|
||||
|
@ -90,10 +90,10 @@ func TestGlobalContextWait(t *testing.T) {
|
|||
|
||||
subTask1 := rootTask.Subtask("subtask1")
|
||||
subTask2 := rootTask.Subtask("subtask2")
|
||||
subTask1.OnComplete("set finished", func() {
|
||||
subTask1.OnFinished("set finished", func() {
|
||||
finished1 = true
|
||||
})
|
||||
subTask2.OnComplete("set finished", func() {
|
||||
subTask2.OnFinished("set finished", func() {
|
||||
finished2 = true
|
||||
})
|
||||
|
||||
|
@ -117,8 +117,8 @@ func TestGlobalContextWait(t *testing.T) {
|
|||
ExpectTrue(t, finished1)
|
||||
ExpectTrue(t, finished2)
|
||||
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
||||
ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context()))
|
||||
ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context()))
|
||||
ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
|
||||
ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
|
||||
}
|
||||
|
||||
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
||||
|
|
|
@ -160,7 +160,7 @@ func (m Map[KT, VT]) Has(k KT) bool {
|
|||
// Returns:
|
||||
//
|
||||
// error: if the unmarshaling fails
|
||||
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
|
||||
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error {
|
||||
if m.Size() != 0 {
|
||||
return E.FailedWhy("unmarshal from yaml", "map is not empty")
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {
|
|||
return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src})
|
||||
}
|
||||
|
||||
func LoadJSON[T any](path string, pointer *T) E.NestedError {
|
||||
func LoadJSON[T any](path string, pointer *T) E.Error {
|
||||
data, err := E.Check(os.ReadFile(path))
|
||||
if err.HasError() {
|
||||
return err
|
||||
|
@ -160,7 +160,7 @@ func LoadJSON[T any](path string, pointer *T) E.NestedError {
|
|||
return E.From(json.Unmarshal(data, pointer))
|
||||
}
|
||||
|
||||
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
|
||||
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error {
|
||||
data, err := E.Check(json.Marshal(pointer))
|
||||
if err.HasError() {
|
||||
return err
|
||||
|
|
|
@ -19,11 +19,11 @@ import (
|
|||
type (
|
||||
SerializedObject = map[string]any
|
||||
Converter interface {
|
||||
ConvertFrom(value any) E.NestedError
|
||||
ConvertFrom(value any) E.Error
|
||||
}
|
||||
)
|
||||
|
||||
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
|
||||
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error {
|
||||
var i any
|
||||
|
||||
err := yaml.Unmarshal(data, &i)
|
||||
|
@ -66,7 +66,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
|
|||
// Returns:
|
||||
// - result: The resulting map[string]any representation of the data.
|
||||
// - error: An error if the data type is unsupported or if there is an error during conversion.
|
||||
func Serialize(data any) (SerializedObject, E.NestedError) {
|
||||
func Serialize(data any) (SerializedObject, E.Error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Use reflection to inspect the data type
|
||||
|
@ -137,7 +137,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
|
|||
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
|
||||
//
|
||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||
func Deserialize(src SerializedObject, dst any) E.NestedError {
|
||||
func Deserialize(src SerializedObject, dst any) E.Error {
|
||||
if src == nil {
|
||||
return E.Invalid("src", "nil")
|
||||
}
|
||||
|
@ -210,7 +210,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
|
|||
//
|
||||
// Returns:
|
||||
// - error: the error occurred during conversion, or nil if no error occurred.
|
||||
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
|
||||
func Convert(src reflect.Value, dst reflect.Value) E.Error {
|
||||
srcT := src.Type()
|
||||
dstT := dst.Type()
|
||||
|
||||
|
@ -277,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
|
|||
return converter.ConvertFrom(src.Interface())
|
||||
}
|
||||
|
||||
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) {
|
||||
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
|
||||
convertible = true
|
||||
if dst.Kind() == reflect.Ptr {
|
||||
if dst.IsNil() {
|
||||
|
@ -379,7 +379,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
|
|||
return true, Convert(reflect.ValueOf(tmp), dst)
|
||||
}
|
||||
|
||||
func DeserializeJSON(j map[string]string, target any) E.NestedError {
|
||||
func DeserializeJSON(j map[string]string, target any) E.Error {
|
||||
data, err := E.Check(json.Marshal(j))
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -113,7 +113,7 @@ type testType struct {
|
|||
bar string
|
||||
}
|
||||
|
||||
func (c *testType) ConvertFrom(v any) E.NestedError {
|
||||
func (c *testType) ConvertFrom(v any) E.Error {
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
c.bar = v
|
||||
|
|
|
@ -98,7 +98,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
|
|||
return got.(T)
|
||||
}
|
||||
|
||||
func Must[T any](v T, err E.NestedError) T {
|
||||
func Must[T any](v T, err E.Error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type DirWatcher struct {
|
|||
mu sync.Mutex
|
||||
|
||||
eventCh chan Event
|
||||
errCh chan E.NestedError
|
||||
errCh chan E.Error
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
@ -48,14 +48,14 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher {
|
|||
w: w,
|
||||
fwMap: F.NewMapOf[string, *fileWatcher](),
|
||||
eventCh: make(chan Event),
|
||||
errCh: make(chan E.NestedError),
|
||||
errCh: make(chan E.Error),
|
||||
ctx: ctx,
|
||||
}
|
||||
go helper.start()
|
||||
return helper
|
||||
}
|
||||
|
||||
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) {
|
||||
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.Error) {
|
||||
return h.eventCh, h.errCh
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,7 @@ func (h *DirWatcher) Add(relPath string) Watcher {
|
|||
s = &fileWatcher{
|
||||
relPath: relPath,
|
||||
eventCh: make(chan Event),
|
||||
errCh: make(chan E.NestedError),
|
||||
errCh: make(chan E.Error),
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
|
|
|
@ -36,6 +36,14 @@ var (
|
|||
|
||||
NewDockerFilter = filters.NewArgs
|
||||
|
||||
optionsDefault = DockerListOptions{Filters: NewDockerFilter(
|
||||
DockerFilterContainer,
|
||||
DockerFilterStart,
|
||||
// DockerFilterStop,
|
||||
DockerFilterDie,
|
||||
DockerFilterDestroy,
|
||||
)}
|
||||
|
||||
dockerWatcherRetryInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
|
@ -61,13 +69,13 @@ func NewDockerWatcherWithClient(client D.Client) DockerWatcher {
|
|||
WithField("host", client.DaemonHost()))}
|
||||
}
|
||||
|
||||
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
|
||||
return w.EventsWithOptions(ctx, optionsWatchAll)
|
||||
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
|
||||
return w.EventsWithOptions(ctx, optionsDefault)
|
||||
}
|
||||
|
||||
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) {
|
||||
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.Error) {
|
||||
eventCh := make(chan Event)
|
||||
errCh := make(chan E.NestedError)
|
||||
errCh := make(chan E.Error)
|
||||
|
||||
go func() {
|
||||
defer close(eventCh)
|
||||
|
@ -80,7 +88,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
|
|||
}()
|
||||
|
||||
if !w.client.Connected() {
|
||||
var err E.NestedError
|
||||
var err E.Error
|
||||
attempts := 0
|
||||
for {
|
||||
w.client, err = D.ConnectClient(w.host)
|
||||
|
@ -141,11 +149,3 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
|
|||
|
||||
return eventCh, errCh
|
||||
}
|
||||
|
||||
var optionsWatchAll = DockerListOptions{Filters: NewDockerFilter(
|
||||
DockerFilterContainer,
|
||||
DockerFilterStart,
|
||||
// DockerFilterStop,
|
||||
DockerFilterDie,
|
||||
DockerFilterDestroy,
|
||||
)}
|
||||
|
|
|
@ -3,20 +3,22 @@ package events
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type (
|
||||
EventQueue struct {
|
||||
task task.Task
|
||||
queue []Event
|
||||
ticker *time.Ticker
|
||||
onFlush OnFlushFunc
|
||||
onError OnErrorFunc
|
||||
task task.Task
|
||||
queue []Event
|
||||
ticker *time.Ticker
|
||||
flushInterval time.Duration
|
||||
onFlush OnFlushFunc
|
||||
onError OnErrorFunc
|
||||
}
|
||||
OnFlushFunc = func(flushTask task.Task, events []Event)
|
||||
OnErrorFunc = func(err E.NestedError)
|
||||
OnErrorFunc = func(err E.Error)
|
||||
)
|
||||
|
||||
const eventQueueCapacity = 10
|
||||
|
@ -35,40 +37,45 @@ const eventQueueCapacity = 10
|
|||
// flushTask.Finish must be called after the flush is done,
|
||||
// but the onFlush function can return earlier (e.g. run in another goroutine).
|
||||
//
|
||||
// If task is cancelled before the flushInterval is reached, the events in queue will be discarded.
|
||||
// If task is canceled before the flushInterval is reached, the events in queue will be discarded.
|
||||
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
|
||||
return &EventQueue{
|
||||
task: parent.Subtask("event queue"),
|
||||
queue: make([]Event, 0, eventQueueCapacity),
|
||||
ticker: time.NewTicker(flushInterval),
|
||||
onFlush: onFlush,
|
||||
onError: onError,
|
||||
task: parent.Subtask("event queue"),
|
||||
queue: make([]Event, 0, eventQueueCapacity),
|
||||
ticker: time.NewTicker(flushInterval),
|
||||
flushInterval: flushInterval,
|
||||
onFlush: onFlush,
|
||||
onError: onError,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) {
|
||||
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
|
||||
go func() {
|
||||
defer e.ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-e.task.Context().Done():
|
||||
e.task.Finish(e.task.FinishCause().Error())
|
||||
return
|
||||
case <-e.ticker.C:
|
||||
if len(e.queue) > 0 {
|
||||
flushTask := e.task.Subtask("flush events")
|
||||
queue := e.queue
|
||||
e.queue = make([]Event, 0, eventQueueCapacity)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name()))
|
||||
}
|
||||
if !common.IsDebug {
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name()))
|
||||
}
|
||||
}()
|
||||
e.onFlush(flushTask, queue)
|
||||
}()
|
||||
e.onFlush(flushTask, queue)
|
||||
}()
|
||||
} else {
|
||||
go e.onFlush(flushTask, queue)
|
||||
}
|
||||
flushTask.Wait()
|
||||
}
|
||||
e.ticker.Reset(e.flushInterval)
|
||||
case event, ok := <-eventCh:
|
||||
e.queue = append(e.queue, event)
|
||||
if !ok {
|
||||
|
|
|
@ -9,9 +9,9 @@ import (
|
|||
type fileWatcher struct {
|
||||
relPath string
|
||||
eventCh chan Event
|
||||
errCh chan E.NestedError
|
||||
errCh chan E.Error
|
||||
}
|
||||
|
||||
func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
|
||||
func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
|
||||
return fw.eventCh, fw.errCh
|
||||
}
|
||||
|
|
28
internal/watcher/health/health_checker.go
Normal file
28
internal/watcher/health/health_checker.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package health
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type (
|
||||
HealthMonitor interface {
|
||||
task.TaskStarter
|
||||
task.TaskFinisher
|
||||
fmt.Stringer
|
||||
json.Marshaler
|
||||
Status() Status
|
||||
Uptime() time.Duration
|
||||
Name() string
|
||||
}
|
||||
HealthChecker interface {
|
||||
CheckHealth() (healthy bool, detail string, err error)
|
||||
URL() types.URL
|
||||
Config() *HealthCheckConfig
|
||||
UpdateURL(url types.URL)
|
||||
}
|
||||
)
|
|
@ -2,9 +2,7 @@ package health
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -15,21 +13,6 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
HealthMonitor interface {
|
||||
task.TaskStarter
|
||||
task.TaskFinisher
|
||||
fmt.Stringer
|
||||
json.Marshaler
|
||||
Status() Status
|
||||
Uptime() time.Duration
|
||||
Name() string
|
||||
}
|
||||
HealthChecker interface {
|
||||
CheckHealth() (healthy bool, detail string, err error)
|
||||
URL() types.URL
|
||||
Config() *HealthCheckConfig
|
||||
UpdateURL(url types.URL)
|
||||
}
|
||||
HealthCheckFunc func() (healthy bool, detail string, err error)
|
||||
monitor struct {
|
||||
service string
|
||||
|
@ -71,7 +54,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
|
|||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
|
||||
func (mon *monitor) Start(routeSubtask task.Task) E.Error {
|
||||
mon.service = routeSubtask.Parent().Name()
|
||||
mon.task = routeSubtask
|
||||
|
||||
|
@ -84,7 +67,6 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
|
|||
if mon.status.Load() != StatusError {
|
||||
mon.status.Store(StatusUnknown)
|
||||
}
|
||||
mon.task.Finish(mon.task.FinishCause().Error())
|
||||
}()
|
||||
|
||||
if err := mon.checkUpdateHealth(); err != nil {
|
||||
|
@ -115,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
|
|||
}
|
||||
|
||||
// Finish implements task.TaskFinisher.
|
||||
func (mon *monitor) Finish(reason string) {
|
||||
func (mon *monitor) Finish(reason any) {
|
||||
mon.task.Finish(reason)
|
||||
}
|
||||
|
||||
|
@ -169,10 +151,10 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
|
|||
}).MarshalJSON()
|
||||
}
|
||||
|
||||
func (mon *monitor) checkUpdateHealth() E.NestedError {
|
||||
func (mon *monitor) checkUpdateHealth() E.Error {
|
||||
healthy, detail, err := mon.checkHealth()
|
||||
if err != nil {
|
||||
defer mon.task.Finish(err.Error())
|
||||
defer mon.task.Finish(err)
|
||||
mon.status.Store(StatusError)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
return E.Failure("check health").With(err)
|
||||
|
|
|
@ -10,5 +10,5 @@ import (
|
|||
type Event = events.Event
|
||||
|
||||
type Watcher interface {
|
||||
Events(ctx context.Context) (<-chan Event, <-chan E.NestedError)
|
||||
Events(ctx context.Context) (<-chan Event, <-chan E.Error)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue