fixed loadbalancer with idlewatcher, fixed reload issue

This commit is contained in:
yusing 2024-10-20 09:46:02 +08:00
parent 01ffe0d97c
commit a278711421
78 changed files with 906 additions and 609 deletions

1
.gitignore vendored
View file

@ -24,3 +24,4 @@ todo.md
.aider*
mtrace.json
.env
test.Dockerfile

View file

@ -55,7 +55,7 @@ repush:
git push gitlab dev --force
rapid-crash:
sudo docker run --restart=always --name test_crash debian:bookworm-slim /bin/cat &&\
sudo docker run --restart=always --name test_crash -p 80 debian:bookworm-slim /bin/cat &&\
sleep 3 &&\
sudo docker rm -f test_crash
@ -64,4 +64,4 @@ debug-list-containers:
ci-test:
mkdir -p /tmp/artifacts
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"

View file

@ -84,7 +84,7 @@ func main() {
middleware.LoadComposeFiles()
var cfg *config.Config
var err E.NestedError
var err E.Error
if cfg, err = config.Load(); err != nil {
logrus.Warn(err)
}

View file

@ -39,7 +39,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return
}
var validateErr E.NestedError
var validateErr E.Error
if filename == common.ConfigFileName {
validateErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {

View file

@ -13,7 +13,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/middleware"
)
func ReloadServer() E.NestedError {
func ReloadServer() E.Error {
resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil)
if err != nil {
return E.From(err)
@ -34,7 +34,7 @@ func ReloadServer() E.NestedError {
return nil
}
func List[T any](what string) (_ T, outErr E.NestedError) {
func List[T any](what string) (_ T, outErr E.Error) {
resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, what))
if err != nil {
outErr = E.From(err)
@ -54,14 +54,14 @@ func List[T any](what string) (_ T, outErr E.NestedError) {
return res, nil
}
func ListRoutes() (map[string]map[string]any, E.NestedError) {
func ListRoutes() (map[string]map[string]any, E.Error) {
return List[map[string]map[string]any](v1.ListRoutes)
}
func ListMiddlewareTraces() (middleware.Traces, E.NestedError) {
func ListMiddlewareTraces() (middleware.Traces, E.Error) {
return List[middleware.Traces](v1.ListMiddlewareTraces)
}
func DebugListTasks() (map[string]any, E.NestedError) {
func DebugListTasks() (map[string]any, E.Error) {
return List[map[string]any](v1.ListTasks)
}

View file

@ -27,7 +27,7 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
return (*Config)(cfg)
}
func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) {
func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
b := E.NewBuilder("unable to initialize autocert")
defer b.To(&res)

View file

@ -29,7 +29,7 @@ type (
tlsCert *tls.Certificate
certExpiries CertExpiries
}
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError)
ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.Error)
CertExpiries map[string]time.Time
)
@ -57,7 +57,7 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries
}
func (p *Provider) ObtainCert() (res E.NestedError) {
func (p *Provider) ObtainCert() (res E.Error) {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
@ -112,7 +112,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) {
return nil
}
func (p *Provider) LoadCert() E.NestedError {
func (p *Provider) LoadCert() E.Error {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.HasError() {
return err
@ -158,7 +158,7 @@ func (p *Provider) ScheduleRenewal() {
}()
}
func (p *Provider) initClient() E.NestedError {
func (p *Provider) initClient() E.Error {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.FailWith("create lego client", err)
@ -178,7 +178,7 @@ func (p *Provider) initClient() E.NestedError {
return nil
}
func (p *Provider) registerACME() E.NestedError {
func (p *Provider) registerACME() E.Error {
if p.user.Registration != nil {
return nil
}
@ -191,7 +191,7 @@ func (p *Provider) registerACME() E.NestedError {
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
func (p *Provider) saveCert(cert *certificate.Resource) E.Error {
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
@ -239,7 +239,7 @@ func (p *Provider) certState() CertState {
return CertStateValid
}
func (p *Provider) renewIfNeeded() E.NestedError {
func (p *Provider) renewIfNeeded() E.Error {
if p.cfg.Provider == ProviderLocal {
return nil
}
@ -259,7 +259,7 @@ func (p *Provider) renewIfNeeded() E.NestedError {
return nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
@ -281,7 +281,7 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg()
err := U.Deserialize(opt, cfg)
if err.HasError() {

View file

@ -6,7 +6,7 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
func (p *Provider) Setup() (err E.NestedError) {
func (p *Provider) Setup() (err E.Error) {
if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err

View file

@ -47,3 +47,5 @@ const (
StopTimeoutDefault = "10s"
StopMethodDefault = "stop"
)
const HeaderCheckRedirect = "X-Goproxy-Check-Redirect"

View file

@ -55,7 +55,7 @@ func newConfig() *Config {
}
}
func Load() (*Config, E.NestedError) {
func Load() (*Config, E.Error) {
if instance != nil {
return instance, nil
}
@ -64,7 +64,7 @@ func Load() (*Config, E.NestedError) {
return instance, instance.load()
}
func Validate(data []byte) E.NestedError {
func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.ConfigSchemaPath), data)
}
@ -78,7 +78,7 @@ func WatchChanges() {
task,
configEventFlushInterval,
OnConfigChange,
func(err E.NestedError) {
func(err E.Error) {
logger.Error(err)
},
)
@ -104,7 +104,7 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
}
}
func Reload() E.NestedError {
func Reload() E.Error {
// avoid race between config change and API reload request
reloadMu.Lock()
defer reloadMu.Unlock()
@ -139,7 +139,7 @@ func (cfg *Config) Task() task.Task {
func (cfg *Config) StartProxyProviders() {
b := E.NewBuilder("errors starting providers")
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
b.Add(p.Start(cfg.task.Subtask("provider %s", p.GetName())))
b.Add(p.Start(cfg.task.Subtask(p.String())))
})
if b.HasError() {
@ -147,7 +147,7 @@ func (cfg *Config) StartProxyProviders() {
}
}
func (cfg *Config) load() (res E.NestedError) {
func (cfg *Config) load() (res E.Error) {
b := E.NewBuilder("errors loading config")
defer b.To(&res)
@ -182,7 +182,7 @@ func (cfg *Config) load() (res E.NestedError) {
return
}
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.NestedError) {
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
if cfg.autocertProvider != nil {
return
}
@ -197,7 +197,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested
return
}
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.NestedError) {
func (cfg *Config) loadProviders(providers *types.ProxyProviders) (outErr E.Error) {
subtask := cfg.task.Subtask("load providers")
defer subtask.Finish("done")

View file

@ -37,7 +37,7 @@ var (
)
func init() {
task.GlobalTask("close docker clients").OnComplete("", func() {
task.GlobalTask("close docker clients").OnFinished("", func() {
clientMap.RangeAllParallel(func(_ string, c Client) {
if c.Connected() {
c.Client.Close()
@ -70,7 +70,7 @@ func (c *SharedClient) Close() error {
// Returns:
// - Client: the Docker client connection.
// - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.NestedError) {
func ConnectClient(host string) (Client, E.Error) {
clientMapMu.Lock()
defer clientMapMu.Unlock()

View file

@ -6,6 +6,8 @@ import (
"fmt"
"strings"
"text/template"
"github.com/yusing/go-proxy/internal/common"
)
type templateData struct {
@ -18,17 +20,15 @@ type templateData struct {
var loadingPage []byte
var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage)))
const headerCheckRedirect = "X-Goproxy-Check-Redirect"
func (w *Watcher) makeLoadingPageBody() []byte {
msg := fmt.Sprintf("%s is starting...", w.ContainerName)
data := new(templateData)
data.CheckRedirectHeader = headerCheckRedirect
data.CheckRedirectHeader = common.HeaderCheckRedirect
data.Title = w.ContainerName
data.Message = strings.ReplaceAll(msg, " ", " ")
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(headerCheckRedirect)))
buf := bytes.NewBuffer(make([]byte, len(loadingPage)+len(data.Title)+len(data.Message)+len(common.HeaderCheckRedirect)))
err := loadingPageTmpl.Execute(buf, data)
if err != nil { // should never happen in production
panic(err)

View file

@ -1,4 +1,4 @@
package idlewatcher
package types
import (
"time"
@ -30,7 +30,7 @@ const (
StopMethodKill StopMethod = "kill"
)
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
if cont == nil {
return nil, nil
}
@ -80,7 +80,7 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.NestedError) {
}, nil
}
func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
func validateDurationPostitive(value string) (time.Duration, E.Error) {
d, err := time.ParseDuration(value)
if err != nil {
return 0, E.Invalid("duration", value).With(err)
@ -91,7 +91,7 @@ func validateDurationPostitive(value string) (time.Duration, E.NestedError) {
return d, nil
}
func validateSignal(s string) (Signal, E.NestedError) {
func validateSignal(s string) (Signal, E.Error) {
switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT":
@ -101,7 +101,7 @@ func validateSignal(s string) (Signal, E.NestedError) {
return "", E.Invalid("signal", s)
}
func validateStopMethod(s string) (StopMethod, E.NestedError) {
func validateStopMethod(s string) (StopMethod, E.Error) {
sm := StopMethod(s)
switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill:

View file

@ -0,0 +1,14 @@
package types
import (
"net/http"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}

View file

@ -1,10 +1,10 @@
package idlewatcher
import (
"net/http"
"sync/atomic"
"time"
. "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
net "github.com/yusing/go-proxy/internal/net/types"
@ -14,12 +14,6 @@ import (
"github.com/yusing/go-proxy/internal/watcher/health"
)
type Waker interface {
health.HealthMonitor
http.Handler
net.Stream
}
type waker struct {
_ U.NoCopy
@ -37,7 +31,7 @@ const (
// TODO: support stream
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.NestedError) {
func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) {
hcCfg := entry.HealthCheckConfig()
hcCfg.Timeout = idleWakerCheckTimeout
@ -62,24 +56,26 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
}
// lifetime should follow route provider
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.NestedError) {
func NewHTTPWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) {
return newWaker(providerSubTask, entry, rp, nil)
}
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.NestedError) {
func NewStreamWaker(providerSubTask task.Task, entry entry.Entry, stream net.Stream) (Waker, E.Error) {
return newWaker(providerSubTask, entry, nil, stream)
}
// Start implements health.HealthMonitor.
func (w *Watcher) Start(routeSubTask task.Task) E.NestedError {
w.task.OnComplete("stop route", func() {
routeSubTask.Parent().Finish("watcher stopped")
func (w *Watcher) Start(routeSubTask task.Task) E.Error {
routeSubTask.Finish("ignored")
w.task.OnCancel("stop route", func() {
routeSubTask.Parent().Finish(w.task.FinishCause())
})
return nil
}
// Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason string) {}
func (w *Watcher) Finish(reason any) {
}
// Name implements health.HealthMonitor.
func (w *Watcher) Name() string {
@ -109,6 +105,7 @@ func (w *Watcher) Status() health.Status {
healthy, _, err := w.hc.CheckHealth()
switch {
case err != nil:
w.ready.Store(false)
return health.StatusError
case healthy:
w.ready.Store(true)

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health"
@ -37,7 +38,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
isCheckRedirect := r.Header.Get(headerCheckRedirect) != ""
isCheckRedirect := r.Header.Get(common.HeaderCheckRedirect) != ""
if !isCheckRedirect && acceptHTML {
// Send a loading response to the client
body := w.makeLoadingPageBody()
@ -56,14 +57,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel()
checkCancelled := func() bool {
checkCanceled := func() bool {
select {
case <-w.task.Context().Done():
w.l.Debugf("wake cancelled: %s", context.Cause(w.task.Context()))
w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
case <-ctx.Done():
w.l.Debugf("wake cancelled: %s", context.Cause(ctx))
w.l.Debugf("wake canceled: %s", context.Cause(ctx))
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout)
return true
default:
@ -71,7 +72,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
}
}
if checkCancelled() {
if checkCanceled() {
return false
}
@ -84,14 +85,14 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
}
for {
if checkCancelled() {
if checkCanceled() {
return false
}
if w.Status() == health.StatusHealthy {
w.resetIdleTimer()
if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting...", w.String())
logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL())
rw.WriteHeader(http.StatusOK)
return
}

View file

@ -8,44 +8,47 @@ import (
"time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health"
)
// Setup implements types.Stream.
func (w *Watcher) Addr() net.Addr {
return w.stream.Addr()
}
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn types.StreamConn, err error) {
func (w *Watcher) Accept() (conn net.Conn, err error) {
conn, err = w.stream.Accept()
// timeout means no connection is accepted
var nErr *net.OpError
ok := errors.As(err, &nErr)
if ok && nErr.Timeout() {
if err != nil {
logrus.Errorf("accept failed with error: %s", err)
return
}
if err := w.wakeFromStream(); err != nil {
return nil, err
w.l.Error(err)
}
return w.stream.Accept()
}
// CloseListeners implements types.Stream.
func (w *Watcher) CloseListeners() {
w.stream.CloseListeners()
return
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn types.StreamConn) error {
func (w *Watcher) Handle(conn net.Conn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
return w.stream.Handle(conn)
}
// Close implements types.Stream.
func (w *Watcher) Close() error {
return w.stream.Close()
}
func (w *Watcher) wakeFromStream() error {
w.resetIdleTimer()
// pass through if container is already ready
if w.ready.Load() {
return nil
@ -66,11 +69,11 @@ func (w *Watcher) wakeFromStream() error {
select {
case <-w.task.Context().Done():
cause := w.task.FinishCause()
w.l.Debugf("wake cancelled: %s", cause)
w.l.Debugf("wake canceled: %s", cause)
return cause
case <-ctx.Done():
cause := context.Cause(ctx)
w.l.Debugf("wake cancelled: %s", cause)
w.l.Debugf("wake canceled: %s", cause)
return cause
default:
}

View file

@ -10,7 +10,7 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task"
@ -49,7 +49,7 @@ var (
const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.NestedError) {
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) {
failure := E.Failure("idle_watcher register")
cfg := entry.IdlewatcherConfig()
@ -66,6 +66,7 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
w.Config = cfg
w.waker = waker
w.resetIdleTimer()
providerSubtask.Finish("used existing watcher")
return w, nil
}
@ -88,13 +89,11 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
go func() {
cause := w.watchUntilDestroy()
watcherMapMu.Lock()
watcherMap.Delete(w.ContainerID)
watcherMapMu.Unlock()
w.ticker.Stop()
w.client.Close()
w.task.Finish(cause.Error())
w.task.Finish(cause)
}()
return w, nil
@ -146,7 +145,7 @@ func (w *Watcher) wakeIfStopped() error {
return err
}
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
ctx, cancel := context.WithTimeout(w.task.Context(), w.WakeTimeout)
defer cancel()
// !Hard coded here since theres no constants from Docker API
@ -175,7 +174,7 @@ func (w *Watcher) getStopCallback() StopCallback {
panic("should not reach here")
}
return func() error {
ctx, cancel := context.WithTimeout(w.task.Context(), dockerReqTimeout)
ctx, cancel := context.WithTimeout(w.task.Context(), time.Duration(w.StopTimeout)*time.Second)
defer cancel()
return cb(ctx)
}
@ -186,8 +185,8 @@ func (w *Watcher) resetIdleTimer() {
w.ticker.Reset(w.IdleTimeout)
}
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.NestedError) {
eventTask = w.task.Subtask("watcher for %s", w.ContainerID)
func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) {
eventTask = w.task.Subtask("docker event watcher")
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter(
W.DockerFilterContainer,
@ -218,13 +217,12 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
func (w *Watcher) watchUntilDestroy() error {
dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped")
for {
select {
case <-w.task.Context().Done():
cause := context.Cause(w.task.Context())
w.l.Debugf("watcher stopped by context done: %s", cause)
return cause
return w.task.FinishCause()
case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err))

View file

@ -8,7 +8,7 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
func Inspect(dockerHost string, containerID string) (*Container, E.NestedError) {
func Inspect(dockerHost string, containerID string) (*Container, E.Error) {
client, err := ConnectClient(dockerHost)
defer client.Close()
@ -19,7 +19,7 @@ func Inspect(dockerHost string, containerID string) (*Container, E.NestedError)
return client.Inspect(containerID)
}
func (c Client) Inspect(containerID string) (*Container, E.NestedError) {
func (c Client) Inspect(containerID string) (*Container, E.Error) {
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel()

View file

@ -39,7 +39,7 @@ func (l *Label) String() string {
//
// Returns:
// - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
func ApplyLabel[T any](obj *T, l *Label) E.Error {
if obj == nil {
return E.Invalid("nil object", l)
}
@ -81,7 +81,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError {
}
}
func ParseLabel(label string, value string) (*Label, E.NestedError) {
func ParseLabel(label string, value string) (*Label, E.Error) {
parts := strings.Split(label, ".")
if len(parts) < 2 {

View file

@ -11,11 +11,6 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
type ClientInfo struct {
Client Client
Containers []types.Container
}
var listOptions = container.ListOptions{
// created|restarting|running|removing|paused|exited|dead
// Filters: filters.NewArgs(
@ -28,28 +23,21 @@ var listOptions = container.ListOptions{
All: true,
}
func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) {
func ListContainers(clientHost string) ([]types.Container, E.Error) {
dockerClient, err := ConnectClient(clientHost)
if err.HasError() {
return nil, E.FailWith("connect to docker", err)
}
defer dockerClient.Close()
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker client connection timeout"))
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
defer cancel()
var containers []types.Container
if getContainer {
containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions))
if err.HasError() {
return nil, E.FailWith("list containers", err)
}
return &ClientInfo{
Client: dockerClient,
Containers: containers,
}, nil
return containers, nil
}
func IsErrConnectionFailed(err error) bool {

View file

@ -12,7 +12,7 @@ type Builder struct {
type builder struct {
message string
errors []NestedError
errors []Error
sync.Mutex
}
@ -28,7 +28,7 @@ func NewBuilder(format string, args ...any) Builder {
// adding nil is no-op,
//
// flatten is a boolean flag to flatten the NestedError.
func (b Builder) Add(err NestedError, flatten ...bool) {
func (b Builder) Add(err Error, flatten ...bool) {
if err != nil {
b.Lock()
if len(flatten) > 0 && flatten[0] {
@ -54,7 +54,7 @@ func (b Builder) Addf(format string, args ...any) {
}
}
func (b Builder) AddRange(errs ...NestedError) {
func (b Builder) AddRange(errs ...Error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
@ -77,14 +77,14 @@ func (b Builder) AddRangeE(errs ...error) {
//
// Returns:
// - NestedError: the built NestedError.
func (b Builder) Build() NestedError {
func (b Builder) Build() Error {
if len(b.errors) == 0 {
return nil
}
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *NestedError) {
func (b Builder) To(ptr *Error) {
switch {
case ptr == nil:
return

View file

@ -16,7 +16,7 @@ func TestBuilderEmpty(t *testing.T) {
func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf")
var err NestedError
var err Error
for range 3 {
eb.Add(nil)
}
@ -53,7 +53,7 @@ func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Addf("abcd")
var err NestedError
var err Error
eb.To(&err)
got := err.String()
expected := (`error occurred:

View file

@ -8,35 +8,35 @@ import (
)
type (
NestedError = *NestedErrorImpl
NestedErrorImpl struct {
Error = *ErrorImpl
ErrorImpl struct {
subject string
err error
extras []NestedErrorImpl
extras []ErrorImpl
}
JSONNestedError struct {
Subject string `json:"subject"`
Err string `json:"error"`
Extras []JSONNestedError `json:"extras,omitempty"`
ErrorJSONMarshaller struct {
Subject string `json:"subject"`
Err string `json:"error"`
Extras []ErrorJSONMarshaller `json:"extras,omitempty"`
}
)
func From(err error) NestedError {
func From(err error) Error {
if IsNil(err) {
return nil
}
return &NestedErrorImpl{err: err}
return &ErrorImpl{err: err}
}
func FromJSON(data []byte) (NestedError, bool) {
var j JSONNestedError
func FromJSON(data []byte) (Error, bool) {
var j ErrorJSONMarshaller
if err := json.Unmarshal(data, &j); err != nil {
return nil, false
}
if j.Err == "" {
return nil, false
}
extras := make([]NestedErrorImpl, len(j.Extras))
extras := make([]ErrorImpl, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) {
}
extras[i] = *extra
}
return &NestedErrorImpl{
return &ErrorImpl{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
@ -53,12 +53,12 @@ func FromJSON(data []byte) (NestedError, bool) {
// Check is a helper function that
// convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, NestedError) {
func Check[T any](obj T, err error) (T, Error) {
return obj, From(err)
}
func Join(message string, err ...NestedError) NestedError {
extras := make([]NestedErrorImpl, len(err))
func Join(message string, err ...Error) Error {
extras := make([]ErrorImpl, len(err))
nErr := 0
for i, e := range err {
if e == nil {
@ -70,13 +70,13 @@ func Join(message string, err ...NestedError) NestedError {
if nErr == 0 {
return nil
}
return &NestedErrorImpl{
return &ErrorImpl{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) NestedError {
func JoinE(message string, err ...error) Error {
b := NewBuilder("%s", message)
for _, e := range err {
b.AddE(e)
@ -92,13 +92,13 @@ func IsNotNil(err error) bool {
return err != nil
}
func (ne NestedError) String() string {
func (ne Error) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne NestedError) Is(err error) bool {
func (ne Error) Is(err error) bool {
if ne == nil {
return err == nil
}
@ -114,18 +114,18 @@ func (ne NestedError) Is(err error) bool {
return false
}
func (ne NestedError) IsNot(err error) bool {
func (ne Error) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne NestedError) Error() error {
func (ne Error) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne NestedError) With(s any) NestedError {
func (ne Error) With(s any) Error {
if ne == nil {
return ne
}
@ -133,7 +133,7 @@ func (ne NestedError) With(s any) NestedError {
switch ss := s.(type) {
case nil:
return ne
case *NestedErrorImpl:
case *ErrorImpl:
if len(ss.extras) == 1 {
ne.extras = append(ne.extras, ss.extras[0])
return ne
@ -151,11 +151,11 @@ func (ne NestedError) With(s any) NestedError {
return ne.withError(From(errors.New(msg)))
}
func (ne NestedError) Extraf(format string, args ...any) NestedError {
func (ne Error) Extraf(format string, args ...any) Error {
return ne.With(errorf(format, args...))
}
func (ne NestedError) Subject(s any, sep ...string) NestedError {
func (ne Error) Subject(s any, sep ...string) Error {
if ne == nil {
return ne
}
@ -179,26 +179,26 @@ func (ne NestedError) Subject(s any, sep ...string) NestedError {
return ne
}
func (ne NestedError) Subjectf(format string, args ...any) NestedError {
func (ne Error) Subjectf(format string, args ...any) Error {
if ne == nil {
return ne
}
return ne.Subject(fmt.Sprintf(format, args...))
}
func (ne NestedError) JSONObject() JSONNestedError {
extras := make([]JSONNestedError, len(ne.extras))
func (ne Error) JSONObject() ErrorJSONMarshaller {
extras := make([]ErrorJSONMarshaller, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return JSONNestedError{
return ErrorJSONMarshaller{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
}
}
func (ne NestedError) JSON() []byte {
func (ne Error) JSON() []byte {
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
if err != nil {
panic(err)
@ -206,19 +206,19 @@ func (ne NestedError) JSON() []byte {
return b
}
func (ne NestedError) NoError() bool {
func (ne Error) NoError() bool {
return ne == nil
}
func (ne NestedError) HasError() bool {
func (ne Error) HasError() bool {
return ne != nil
}
func errorf(format string, args ...any) NestedError {
func errorf(format string, args ...any) Error {
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
@ -226,14 +226,14 @@ func fromJSONObject(obj JSONNestedError) (NestedError, bool) {
return FromJSON(data)
}
func (ne NestedError) withError(err NestedError) NestedError {
func (ne Error) withError(err Error) Error {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne NestedError) appendMsg(msg string) NestedError {
func (ne Error) appendMsg(msg string) Error {
if ne == nil {
return nil
}
@ -241,7 +241,7 @@ func (ne NestedError) appendMsg(msg string) NestedError {
return ne
}
func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
for range level {
sb.WriteString(" ")
}
@ -266,7 +266,7 @@ func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) {
}
}
func (ne NestedError) buildError(level int, prefix string) error {
func (ne Error) buildError(level int, prefix string) error {
var res error
var sb strings.Builder

View file

@ -26,7 +26,7 @@ func TestErrorIs(t *testing.T) {
}
func TestErrorNestedIs(t *testing.T) {
var err NestedError
var err Error
ExpectTrue(t, err.Is(nil))
err = Failure("some reason")
@ -40,7 +40,7 @@ func TestErrorNestedIs(t *testing.T) {
}
func TestIsNil(t *testing.T) {
var err NestedError
var err Error
ExpectTrue(t, err.Is(nil))
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())

View file

@ -22,62 +22,62 @@ var (
const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) NestedError {
func Failure(what string) Error {
return errorf("%s %w", what, ErrFailure)
}
func FailedWhy(what string, why string) NestedError {
func FailedWhy(what string, why string) Error {
return Failure(what).With(why)
}
func FailWith(what string, err any) NestedError {
func FailWith(what string, err any) Error {
return Failure(what).With(err)
}
func Invalid(subject, what any) NestedError {
func Invalid(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) NestedError {
func Unsupported(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func Unexpected(subject, what any) NestedError {
func Unexpected(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func UnexpectedError(err error) NestedError {
func UnexpectedError(err error) Error {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) NestedError {
func NotExist(subject, what any) Error {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func Missing(subject any) NestedError {
func Missing(subject any) Error {
return errorf("%w %v", ErrMissing, subject)
}
func Duplicated(subject, what any) NestedError {
func Duplicated(subject, what any) Error {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject any, value any) NestedError {
func OutOfRange(subject any, value any) Error {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}
func TypeError(subject any, from, to reflect.Type) NestedError {
func TypeError(subject any, from, to reflect.Type) Error {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) NestedError {
func TypeError2(subject any, from, to reflect.Value) Error {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) NestedError {
func TypeMismatch[Expect any](value any) Error {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}
func PanicRecv(format string, args ...any) NestedError {
func PanicRecv(format string, args ...any) Error {
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
}

View file

@ -0,0 +1,15 @@
package loadbalancer
import "net/http"
type DummyResponseWriter struct{}
func (w *DummyResponseWriter) Header() (_ http.Header) {
return
}
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w *DummyResponseWriter) WriteHeader(int) {}

View file

@ -21,7 +21,7 @@ func (lb *LoadBalancer) newIPHash() impl {
if len(lb.Options) == 0 {
return impl
}
var err E.NestedError
var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err)

View file

@ -1,10 +1,13 @@
package loadbalancer
import (
"context"
"net/http"
"sync"
"time"
"github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task"
@ -54,10 +57,10 @@ func New(cfg *Config) *LoadBalancer {
}
// Start implements task.TaskStarter.
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error {
lb.startTime = time.Now()
lb.task = routeSubtask
lb.task.OnComplete("loadbalancer cleanup", func() {
lb.task.OnFinished("loadbalancer cleanup", func() {
if lb.impl != nil {
lb.pool.RangeAll(func(k string, v *Server) {
lb.impl.OnRemoveServer(v)
@ -69,7 +72,7 @@ func (lb *LoadBalancer) Start(routeSubtask task.Task) E.NestedError {
}
// Finish implements task.TaskFinisher.
func (lb *LoadBalancer) Finish(reason string) {
func (lb *LoadBalancer) Finish(reason any) {
lb.task.Finish(reason)
}
@ -128,7 +131,7 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.rebalance()
lb.impl.OnAddServer(srv)
logger.Infof("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
}
func (lb *LoadBalancer) RemoveServer(srv *Server) {
@ -147,11 +150,11 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
if lb.pool.Size() == 0 {
lb.task.Finish("no server left")
logger.Infof("[remove] loadbalancer %s stopped", lb.Link)
logger.Infof("loadbalancer %s stopped", lb.Link)
return
}
logger.Infof("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
}
func (lb *LoadBalancer) rebalance() {
@ -211,6 +214,21 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return
}
if r.Header.Get(common.HeaderCheckRedirect) != "" {
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
defer cancel()
// send dummy request to wake all servers
var dummyRW *DummyResponseWriter
for _, srv := range srvs {
// wake only if server implements Waker
_, ok := srv.handler.(idlewatcher.Waker)
if !ok {
continue
}
wakeReq := r.Clone(ctx)
srv.ServeHTTP(dummyRW, wakeReq)
}
}
lb.impl.ServeHTTP(srvs, rw, r)
}
@ -261,10 +279,9 @@ func (lb *LoadBalancer) String() string {
func (lb *LoadBalancer) availServers() []*Server {
avail := make([]*Server, 0, lb.pool.Size())
lb.pool.RangeAll(func(_ string, srv *Server) {
if srv.Status().Bad() {
return
if srv.Status().Good() {
avail = append(avail, srv)
}
avail = append(avail, srv)
})
return avail
}

View file

@ -14,8 +14,8 @@ func (lb *roundRobin) OnAddServer(srv *Server) {}
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
index := lb.index.Add(1)
srvs[index%uint32(len(srvs))].ServeHTTP(rw, r)
index := lb.index.Add(1) % uint32(len(srvs))
srvs[index].ServeHTTP(rw, r)
if lb.index.Load() >= 2*uint32(len(srvs)) {
lb.index.Store(0)
}

View file

@ -35,7 +35,7 @@ var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
}
}
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
wl := new(cidrWhitelist)
wl.m = &Middleware{
impl: wl,

View file

@ -33,7 +33,7 @@ var CloudflareRealIP = &realIP{
m: &Middleware{withOptions: NewCloudflareRealIP},
}
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.Error) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,

View file

@ -36,7 +36,7 @@ var ForwardAuth = &forwardAuth{
m: &Middleware{withOptions: NewForwardAuthfunc},
}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts)

View file

@ -11,7 +11,7 @@ import (
)
type (
Error = E.NestedError
Error = E.Error
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
@ -24,7 +24,7 @@ type (
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
Options any
@ -77,7 +77,7 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if len(optsRaw) != 0 && m.withOptions != nil {
return m.withOptions(optsRaw)
}
@ -108,7 +108,7 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) {
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) {
middlewares = make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
@ -136,7 +136,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
return
}
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) {
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
if err != nil {

View file

@ -10,7 +10,7 @@ import (
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
@ -18,7 +18,7 @@ func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E
return BuildMiddlewaresFromYAML(fileContent)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)

View file

@ -22,7 +22,7 @@ var ModifyRequest = &modifyRequest{
m: &Middleware{withOptions: NewModifyRequest},
}
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest)
var mrFunc RewriteFunc
if common.IsDebug {

View file

@ -24,7 +24,7 @@ var ModifyResponse = &modifyResponse{
m: &Middleware{withOptions: NewModifyResponse},
}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
mr.m = &Middleware{impl: mr}
if common.IsDebug {

View file

@ -26,7 +26,7 @@ var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2},
}
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
oauth := new(oAuth2)
oauth.m = &Middleware{
impl: oauth,

View file

@ -41,7 +41,7 @@ var realIPOptsDefault = func() *realIPOpts {
}
}
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
func NewRealIP(opts OptionsRaw) (*Middleware, E.Error) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,

View file

@ -72,7 +72,7 @@ type testArgs struct {
scheme string
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
var body io.Reader
var rr requestRecorder
var proxyURL *url.URL

View file

@ -9,7 +9,7 @@ import (
type CIDR net.IPNet
func (cidr *CIDR) ConvertFrom(val any) E.NestedError {
func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string)
if !ok {
return E.TypeMismatch[string](val)

View file

@ -7,13 +7,7 @@ import (
type Stream interface {
fmt.Stringer
net.Listener
Setup() error
Accept() (conn StreamConn, err error)
Handle(conn StreamConn) error
CloseListeners()
}
type StreamConn interface {
RemoteAddr() net.Addr
Close() error
Handle(conn net.Conn) error
}

View file

@ -1,7 +1,7 @@
package entry
import (
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"
@ -18,7 +18,7 @@ type Entry interface {
IdlewatcherConfig() *idlewatcher.Config
}
func ValidateEntry(m *RawEntry) (Entry, E.NestedError) {
func ValidateEntry(m *RawEntry) (Entry, E.Error) {
m.FillMissingFields()
scheme, err := T.NewScheme(m.Scheme)

View file

@ -5,7 +5,7 @@ import (
"net/url"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"

View file

@ -4,7 +4,7 @@ import (
"fmt"
"github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/config"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer"
net "github.com/yusing/go-proxy/internal/net/types"

View file

@ -7,7 +7,7 @@ import (
E "github.com/yusing/go-proxy/internal/error"
)
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) {
func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.Error) {
h := make(http.Header)
for k, v := range headers {
vSplit := strings.Split(v, ",")

View file

@ -9,6 +9,6 @@ type (
Subdomain = Alias
)
func ValidateHost[String ~string](s String) (Host, E.NestedError) {
func ValidateHost[String ~string](s String) (Host, E.Error) {
return Host(s), nil
}

View file

@ -13,7 +13,7 @@ type (
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
func ValidatePathPattern(s string) (PathPattern, E.Error) {
if len(s) == 0 {
return "", E.Invalid("path", "must not be empty")
}
@ -23,7 +23,7 @@ func ValidatePathPattern(s string) (PathPattern, E.NestedError) {
return PathPattern(s), nil
}
func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) {
func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
if len(s) == 0 {
return []PathPattern{"/"}, nil
}

View file

@ -8,7 +8,7 @@ import (
type Port int
func ValidatePort[String ~string](v String) (Port, E.NestedError) {
func ValidatePort[String ~string](v String) (Port, E.Error) {
p, err := strconv.Atoi(string(v))
if err != nil {
return ErrPort, E.Invalid("port number", v).With(err)
@ -16,7 +16,7 @@ func ValidatePort[String ~string](v String) (Port, E.NestedError) {
return ValidatePortInt(p)
}
func ValidatePortInt[Int int | uint16](v Int) (Port, E.NestedError) {
func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) {
p := Port(v)
if !p.inBound() {
return ErrPort, E.OutOfRange("port", p)

View file

@ -6,7 +6,7 @@ import (
type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.NestedError) {
func NewScheme[String ~string](s String) (Scheme, E.Error) {
switch s {
case "http", "https", "tcp", "udp":
return Scheme(s), nil

View file

@ -12,7 +12,7 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"`
}
func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) {
func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
split := strings.Split(p, ":")
switch len(split) {
@ -47,7 +47,7 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.NestedError) {
return StreamPort{listeningPort, proxyPort}, nil
}
func parseNameToPort(name string) (Port, E.NestedError) {
func parseNameToPort(name string) (Port, E.Error) {
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return ErrPort, E.Invalid("service", name)

View file

@ -12,7 +12,7 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"`
}
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) {
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) {
ss = &StreamScheme{}
parts := strings.Split(s, ":")
if len(parts) == 1 {

View file

@ -66,7 +66,7 @@ func SetFindMuxDomains(domains []string) {
}
}
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.NestedError) {
func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
var trans *http.Transport
if entry.NoTLSVerify {
@ -97,7 +97,7 @@ func (r *HTTPRoute) String() string {
}
// Start implements task.TaskStarter.
func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
@ -151,7 +151,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
r.addToLoadBalancer()
} else {
httpRoutes.Store(string(r.Alias), r)
r.task.OnComplete("stop rp", func() {
r.task.OnFinished("remove from route table", func() {
httpRoutes.Delete(string(r.Alias))
})
}
@ -160,7 +160,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.NestedError {
}
// Finish implements task.TaskFinisher.
func (r *HTTPRoute) Finish(reason string) {
func (r *HTTPRoute) Finish(reason any) {
r.task.Finish(reason)
}
@ -175,8 +175,8 @@ func (r *HTTPRoute) addToLoadBalancer() {
}
} else {
lb = loadbalancer.New(r.LoadBalance)
lbTask := r.task.Parent().Subtask("loadbalancer %s", r.LoadBalance.Link)
lbTask.OnComplete("remove lb from routes", func() {
lbTask := r.task.Parent().Subtask("loadbalancer " + r.LoadBalance.Link)
lbTask.OnCancel("remove lb from routes", func() {
httpRoutes.Delete(r.LoadBalance.Link)
})
lb.Start(lbTask)
@ -194,9 +194,9 @@ func (r *HTTPRoute) addToLoadBalancer() {
httpRoutes.Store(r.LoadBalance.Link, linked)
}
r.loadBalancer = lb
r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
r.server = loadbalancer.NewServer(r.task.String(), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.HealthMon)
lb.AddServer(r.server)
r.task.OnComplete("remove server from lb", func() {
r.task.OnCancel("remove server from lb", func() {
lb.RemoveServer(r.server)
})
}

View file

@ -25,7 +25,7 @@ var (
AliasRefRegexOld = regexp.MustCompile(`\$\d+`)
)
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) {
func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) {
if dockerHost == common.DockerHostFromEnv {
dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost)
}
@ -40,18 +40,18 @@ func (p *DockerProvider) NewWatcher() W.Watcher {
return W.NewDockerWatcher(p.dockerHost)
}
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) {
routes = R.NewRoutes()
entries := entry.NewProxyEntries()
info, err := D.GetClientInfo(p.dockerHost, true)
containers, err := D.ListContainers(p.dockerHost)
if err != nil {
return routes, E.FailWith("connect to docker", err)
return routes, err
}
errors := E.NewBuilder("errors in docker labels")
for _, c := range info.Containers {
for _, c := range containers {
container := D.FromDocker(&c, p.dockerHost)
if container.IsExcluded {
continue
@ -70,10 +70,6 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) {
})
}
entries.RangeAll(func(_ string, e *entry.RawEntry) {
e.Container.DockerHost = p.dockerHost
})
routes, err = R.FromEntries(entries)
errors.Add(err)
@ -89,7 +85,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool {
// Returns a list of proxy entries for a container.
// Always non-nil.
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.NestedError) {
func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) {
entries = entry.NewProxyEntries()
if p.shouldIgnore(container) {
@ -117,7 +113,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent
return entries, errors.Build().Subject(container.ContainerName)
}
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.NestedError) {
func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) {
b := E.NewBuilder("errors in label %s", key)
defer b.To(&res)

View file

@ -1,7 +1,9 @@
package provider
import (
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher"
@ -32,31 +34,52 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) {
return
}
oldRoutes.RangeAll(func(k string, v *route.Route) {
if !newRoutes.Has(k) {
handler.Remove(v)
if common.IsDebug {
eventsLog := E.NewBuilder("events")
for _, event := range events {
eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID)
}
handler.provider.l.Debug(eventsLog.String())
oldRoutesLog := E.NewBuilder("old routes")
oldRoutes.RangeAll(func(k string, r *route.Route) {
oldRoutesLog.Addf(k)
})
handler.provider.l.Debug(oldRoutesLog.String())
newRoutesLog := E.NewBuilder("new routes")
newRoutes.RangeAll(func(k string, r *route.Route) {
newRoutesLog.Addf(k)
})
handler.provider.l.Debug(newRoutesLog.String())
}
oldRoutes.RangeAll(func(k string, oldr *route.Route) {
newr, ok := newRoutes.Load(k)
if !ok {
handler.Remove(oldr)
} else if handler.matchAny(events, newr) {
handler.Update(parent, oldr, newr)
} else if entry.ShouldNotServe(newr) {
handler.Remove(oldr)
}
})
newRoutes.RangeAll(func(k string, newr *route.Route) {
if oldRoutes.Has(k) {
for _, ev := range events {
if handler.match(ev, newr) {
old, ok := oldRoutes.Load(k)
if !ok { // should not happen
panic("race condition")
}
handler.Update(parent, old, newr)
return
}
}
} else {
if !(oldRoutes.Has(k) || entry.ShouldNotServe(newr)) {
handler.Add(parent, newr)
}
})
}
func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route) bool {
for _, event := range events {
if handler.match(event, route) {
return true
}
}
return false
}
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.t {
switch handler.provider.GetType() {
case ProviderTypeDocker:
return route.Entry.Container.ContainerID == event.ActorID ||
route.Entry.Container.ContainerName == event.ActorName
@ -70,14 +93,15 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool
func (handler *EventHandler) Add(parent task.Task, route *route.Route) {
err := handler.provider.startRoute(parent, route)
if err != nil {
handler.errs.Add(err)
handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err))
} else {
handler.added = append(handler.added, route.Entry.Alias)
}
}
func (handler *EventHandler) Remove(route *route.Route) {
route.Finish("route removal")
route.Finish("route removed")
handler.provider.routes.Delete(route.Entry.Alias)
handler.removed = append(handler.removed, route.Entry.Alias)
}
@ -85,7 +109,7 @@ func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, new
oldRoute.Finish("route update")
err := handler.provider.startRoute(parent, newRoute)
if err != nil {
handler.errs.Add(err)
handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err))
} else {
handler.updated = append(handler.updated, newRoute.Entry.Alias)
}

View file

@ -18,7 +18,7 @@ type FileProvider struct {
path string
}
func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
func FileProviderImpl(filename string) (ProviderImpl, E.Error) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) {
}
}
func Validate(data []byte) E.NestedError {
func Validate(data []byte) E.Error {
return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data)
}
@ -42,7 +42,7 @@ func (p FileProvider) String() string {
return p.fileName
}
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) {
func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) {
routes = R.NewRoutes()
b := E.NewBuilder("validation failure")

View file

@ -7,7 +7,6 @@ import (
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/proxy/entry"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/task"
W "github.com/yusing/go-proxy/internal/watcher"
@ -29,7 +28,7 @@ type (
ProviderImpl interface {
fmt.Stringer
NewWatcher() W.Watcher
LoadRoutesImpl() (R.Routes, E.NestedError)
LoadRoutesImpl() (R.Routes, E.Error)
}
ProviderType string
ProviderStats struct {
@ -43,7 +42,7 @@ const (
ProviderTypeDocker ProviderType = "docker"
ProviderTypeFile ProviderType = "file"
providerEventFlushInterval = 500 * time.Millisecond
providerEventFlushInterval = 300 * time.Millisecond
)
func newProvider(name string, t ProviderType) *Provider {
@ -56,7 +55,7 @@ func newProvider(name string, t ProviderType) *Provider {
return p
}
func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
func NewFileProvider(filename string) (p *Provider, err E.Error) {
name := path.Base(filename)
if name == "" {
return nil, E.Invalid("file name", "empty")
@ -70,7 +69,7 @@ func NewFileProvider(filename string) (p *Provider, err E.NestedError) {
return
}
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.NestedError) {
func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) {
if name == "" {
return nil, E.Invalid("provider name", "empty")
}
@ -101,18 +100,16 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil
}
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
if entry.UseLoadBalance(r) {
r.Entry.Alias = p.String() + "/" + r.Entry.Alias
}
subtask := parent.Subtask(r.Entry.Alias)
func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error {
subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias)
err := r.Start(subtask)
if err != nil {
p.routes.Delete(r.Entry.Alias)
subtask.Finish(err.String()) // just to ensure
subtask.Finish(err) // just to ensure
return err
} else {
subtask.OnComplete("del from provider", func() {
p.routes.Store(r.Entry.Alias, r)
subtask.OnFinished("del from provider", func() {
p.routes.Delete(r.Entry.Alias)
})
}
@ -120,7 +117,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.NestedError {
}
// Start implements task.TaskStarter.
func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
func (p *Provider) Start(configSubtask task.Task) (res E.Error) {
errors := E.NewBuilder("errors starting routes")
defer errors.To(&res)
@ -141,7 +138,7 @@ func (p *Provider) Start(configSubtask task.Task) (res E.NestedError) {
handler.Log()
flushTask.Finish("events flushed")
},
func(err E.NestedError) {
func(err E.Error) {
p.l.Error(err)
},
)
@ -157,8 +154,8 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) {
return p.routes.Load(alias)
}
func (p *Provider) LoadRoutes() E.NestedError {
var err E.NestedError
func (p *Provider) LoadRoutes() E.Error {
var err E.Error
p.routes, err = p.LoadRoutesImpl()
if p.routes.Size() > 0 {
return err

94
internal/route/raw.go Normal file
View file

@ -0,0 +1,94 @@
package route
import (
"errors"
"fmt"
"net"
"time"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
RawStream struct {
*StreamRoute
listener net.Listener
targetAddr net.Addr
}
)
const (
streamBufferSize = 8192
streamDialTimeout = 5 * time.Second
)
func NewRawStreamRoute(base *StreamRoute) *RawStream {
return &RawStream{
StreamRoute: base,
}
}
func (route *RawStream) Setup() error {
var lcfg net.ListenConfig
var err error
switch route.Scheme.ListeningScheme {
case "tcp":
route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
route.listener = tcpListener
case "udp":
route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port)
route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener)
default:
return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme))
}
return nil
}
func (route *RawStream) Accept() (net.Conn, error) {
if route.listener == nil {
return nil, errors.New("listener not yet set up")
}
return route.listener.Accept()
}
func (route *RawStream) Handle(c net.Conn) error {
clientConn := c.(net.Conn)
defer clientConn.Close()
route.task.OnCancel("close conn", func() { clientConn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr)
if err != nil {
return err
}
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()
}
func (route *RawStream) Close() error {
return route.listener.Close()
}

View file

@ -44,7 +44,7 @@ func (rt *Route) Container() *docker.Container {
return rt.Entry.Container
}
func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
func NewRoute(raw *entry.RawEntry) (*Route, E.Error) {
en, err := entry.ValidateEntry(raw)
if err != nil {
return nil, err
@ -73,7 +73,7 @@ func NewRoute(raw *entry.RawEntry) (*Route, E.NestedError) {
}, nil
}
func FromEntries(entries entry.RawEntries) (Routes, E.NestedError) {
func FromEntries(entries entry.RawEntries) (Routes, E.Error) {
b := E.NewBuilder("errors in routes")
routes := NewRoutes()

View file

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
stdNet "net"
"sync"
"github.com/sirupsen/logrus"
@ -37,7 +36,7 @@ func GetStreamProxies() F.Map[string, *StreamRoute] {
return streamRoutes
}
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) {
// TODO: support non-coherent scheme
if !entry.Scheme.IsCoherent() {
return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme))
@ -48,16 +47,12 @@ func NewStreamRoute(entry *entry.StreamEntry) (impl, E.NestedError) {
}, nil
}
func (r *StreamRoute) Finish(reason string) {
r.task.Finish(reason)
}
func (r *StreamRoute) String() string {
return fmt.Sprintf("stream %s", r.Alias)
}
// Start implements task.TaskStarter.
func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
func (r *StreamRoute) Start(providerSubtask task.Task) E.Error {
if entry.ShouldNotServe(r) {
providerSubtask.Finish("should not serve")
return nil
@ -71,11 +66,13 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
r.HealthCheck.Disable = true
}
if r.Scheme.ListeningScheme.IsTCP() {
r.Stream = NewTCPRoute(r)
} else {
r.Stream = NewUDPRoute(r)
}
// if r.Scheme.ListeningScheme.IsTCP() {
// r.Stream = NewTCPRoute(r)
// } else {
// r.Stream = NewUDPRoute(r)
// }
r.task = providerSubtask
r.Stream = NewRawStreamRoute(r)
r.l = logrus.WithField("route", r.Stream.String())
switch {
@ -83,6 +80,7 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias))
waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream)
if err != nil {
r.task.Finish(err)
return err
}
r.Stream = waker
@ -90,24 +88,41 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.NestedError {
case entry.UseHealthCheck(r):
r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck)
}
r.task = providerSubtask
r.task.OnComplete("stop stream", r.CloseListeners)
if err := r.Setup(); err != nil {
r.task.Finish(err)
return E.FailWith("setup", err)
}
r.l.Infof("listening on port %d", r.Port.ListeningPort)
go r.acceptConnections()
r.task.OnFinished("close stream", func() {
if err := r.Close(); err != nil {
r.l.Error("close stream error: ", err)
}
})
r.task.OnFinished("remove from route table", func() {
streamRoutes.Delete(string(r.Alias))
})
r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort)
if r.HealthMon != nil {
r.HealthMon.Start(r.task.Subtask("health monitor"))
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn("health monitor error: ", err)
}
}
go r.acceptConnections()
streamRoutes.Store(string(r.Alias), r)
return nil
}
func (r *StreamRoute) Finish(reason any) {
r.task.Finish(reason)
}
func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed")
for {
select {
case <-r.task.Context().Done():
@ -117,24 +132,17 @@ func (r *StreamRoute) acceptConnections() {
if err != nil {
select {
case <-r.task.Context().Done():
return
default:
var nErr *stdNet.OpError
ok := errors.As(err, &nErr)
if !(ok && nErr.Timeout()) {
r.l.Error("accept connection error: ", err)
r.task.Finish(err.Error())
return
}
continue
r.l.Error("accept connection error: ", err)
r.task.Finish(err)
}
return
}
connTask := r.task.Subtask("%s connection from %s", conn.RemoteAddr().Network(), conn.RemoteAddr().String())
connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr()))
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
r.l.Error(err)
connTask.Finish(err.Error())
} else {
connTask.Finish("connection closed")
}

View file

@ -1,71 +1,68 @@
package route
import (
"context"
"fmt"
"net"
"time"
// import (
// "context"
// "fmt"
// "net"
// "time"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
// "github.com/yusing/go-proxy/internal/net/types"
// T "github.com/yusing/go-proxy/internal/proxy/fields"
// U "github.com/yusing/go-proxy/internal/utils"
// F "github.com/yusing/go-proxy/internal/utils/functional"
// )
const tcpDialTimeout = 5 * time.Second
// const tcpDialTimeout = 5 * time.Second
type (
TCPConnMap = F.Map[net.Conn, struct{}]
TCPRoute struct {
*StreamRoute
listener *net.TCPListener
}
)
// type (
// TCPConnMap = F.Map[net.Conn, struct{}]
// TCPRoute struct {
// *StreamRoute
// listener *net.TCPListener
// }
// )
func NewTCPRoute(base *StreamRoute) *TCPRoute {
return &TCPRoute{StreamRoute: base}
}
// func NewTCPRoute(base *StreamRoute) *TCPRoute {
// return &TCPRoute{StreamRoute: base}
// }
func (route *TCPRoute) Setup() error {
in, err := net.Listen("tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
//! this read the allocated port from original ':0'
route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
route.listener = in.(*net.TCPListener)
return nil
}
// func (route *TCPRoute) Setup() error {
// var cfg net.ListenConfig
// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort))
// if err != nil {
// return err
// }
// //! this read the allocated port from original ':0'
// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port)
// route.listener = in.(*net.TCPListener)
// return nil
// }
func (route *TCPRoute) Accept() (types.StreamConn, error) {
route.listener.SetDeadline(time.Now().Add(time.Second))
return route.listener.Accept()
}
// func (route *TCPRoute) Accept() (types.StreamConn, error) {
// return route.listener.Accept()
// }
func (route *TCPRoute) Handle(c types.StreamConn) error {
clientConn := c.(net.Conn)
// func (route *TCPRoute) Handle(c types.StreamConn) error {
// clientConn := c.(net.Conn)
defer clientConn.Close()
route.task.OnComplete("close conn", func() { clientConn.Close() })
// defer clientConn.Close()
// route.task.OnCancel("close conn", func() { clientConn.Close() })
ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout)
serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
dialer := &net.Dialer{}
// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)
// dialer := &net.Dialer{}
serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
cancel()
if err != nil {
return err
}
// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr)
// cancel()
// if err != nil {
// return err
// }
pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
return pipe.Start()
}
// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn)
// return pipe.Start()
// }
func (route *TCPRoute) CloseListeners() {
if route.listener == nil {
return
}
route.listener.Close()
}
// func (route *TCPRoute) Close() error {
// return route.listener.Close()
// }

View file

@ -1,145 +1,149 @@
package route
import (
"errors"
"fmt"
"io"
"net"
"time"
// import (
// "errors"
// "fmt"
// "io"
// "net"
"github.com/yusing/go-proxy/internal/net/types"
T "github.com/yusing/go-proxy/internal/proxy/fields"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
// "github.com/yusing/go-proxy/internal/net/types"
// T "github.com/yusing/go-proxy/internal/proxy/fields"
// U "github.com/yusing/go-proxy/internal/utils"
// F "github.com/yusing/go-proxy/internal/utils/functional"
// )
type (
UDPRoute struct {
*StreamRoute
// type (
// UDPRoute struct {
// *StreamRoute
connMap UDPConnMap
// connMap UDPConnMap
listeningConn *net.UDPConn
targetAddr *net.UDPAddr
}
UDPConn struct {
key string
src *net.UDPConn
dst *net.UDPConn
U.BidirectionalPipe
}
UDPConnMap = F.Map[string, *UDPConn]
)
// listeningConn net.PacketConn
// targetAddr *net.UDPAddr
// }
// UDPConn struct {
// key string
// src net.Conn
// dst net.Conn
// U.BidirectionalPipe
// }
// UDPConnMap = F.Map[string, *UDPConn]
// )
var NewUDPConnMap = F.NewMap[UDPConnMap]
// var NewUDPConnMap = F.NewMap[UDPConnMap]
const udpBufferSize = 8192
// const udpBufferSize = 8192
func NewUDPRoute(base *StreamRoute) *UDPRoute {
return &UDPRoute{
StreamRoute: base,
connMap: NewUDPConnMap(),
}
}
// func NewUDPRoute(base *StreamRoute) *UDPRoute {
// return &UDPRoute{
// StreamRoute: base,
// connMap: NewUDPConnMap(),
// }
// }
func (route *UDPRoute) Setup() error {
laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
if err != nil {
return err
}
source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
if err != nil {
return err
}
raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
if err != nil {
source.Close()
return err
}
// func (route *UDPRoute) Setup() error {
// var cfg net.ListenConfig
// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
// if err != nil {
// return err
// }
// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort))
// if err != nil {
// source.Close()
// return err
// }
//! this read the allocated listeningPort from original ':0'
route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
// //! this read the allocated listeningPort from original ':0'
// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
route.listeningConn = source
route.targetAddr = raddr
// route.listeningConn = source
// route.targetAddr = raddr
return nil
}
// return nil
// }
func (route *UDPRoute) Accept() (types.StreamConn, error) {
in := route.listeningConn
// func (route *UDPRoute) Accept() (types.StreamConn, error) {
// in := route.listeningConn
buffer := make([]byte, udpBufferSize)
route.listeningConn.SetReadDeadline(time.Now().Add(time.Second))
nRead, srcAddr, err := in.ReadFromUDP(buffer)
if err != nil {
return nil, err
}
// buffer := make([]byte, udpBufferSize)
// nRead, srcAddr, err := in.ReadFrom(buffer)
// if err != nil {
// return nil, err
// }
if nRead == 0 {
return nil, io.ErrShortBuffer
}
// if nRead == 0 {
// return nil, io.ErrShortBuffer
// }
key := srcAddr.String()
conn, ok := route.connMap.Load(key)
// key := srcAddr.String()
// conn, ok := route.connMap.Load(key)
if !ok {
srcConn, err := net.DialUDP("udp", nil, srcAddr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
if err != nil {
srcConn.Close()
return nil, err
}
conn = &UDPConn{
key,
srcConn,
dstConn,
U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
}
route.connMap.Store(key, conn)
}
// if !ok {
// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String())
// if err != nil {
// return nil, err
// }
// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String())
// if err != nil {
// srcConn.Close()
// return nil, err
// }
// conn = &UDPConn{
// key,
// srcConn,
// dstConn,
// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}),
// }
// route.connMap.Store(key, conn)
// }
_, err = conn.dst.Write(buffer[:nRead])
return conn, err
}
// _, err = conn.dst.Write(buffer[:nRead])
// return conn, err
// }
func (route *UDPRoute) Handle(c types.StreamConn) error {
conn := c.(*UDPConn)
err := conn.Start()
route.connMap.Delete(conn.key)
return err
}
// func (route *UDPRoute) Handle(c types.StreamConn) error {
// switch c := c.(type) {
// case *UDPConn:
// err := c.Start()
// route.connMap.Delete(c.key)
// c.Close()
// return err
// case *net.TCPConn:
// in := route.listeningConn
// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr))
// if err != nil {
// return err
// }
// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start()
// c.Close()
// return err
// }
// return fmt.Errorf("unknown conn type: %T", c)
// }
func (route *UDPRoute) CloseListeners() {
if route.listeningConn != nil {
route.listeningConn.Close()
}
route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
if err := conn.Close(); err != nil {
route.l.Errorf("error closing conn: %s", err)
}
})
route.connMap.Clear()
}
// func (route *UDPRoute) Close() error {
// route.connMap.RangeAllParallel(func(k string, v *UDPConn) {
// v.Close()
// })
// route.connMap.Clear()
// return route.listeningConn.Close()
// }
// Close implements types.StreamConn
func (conn *UDPConn) Close() error {
return errors.Join(conn.src.Close(), conn.dst.Close())
}
// // Close implements types.StreamConn
// func (conn *UDPConn) Close() error {
// return errors.Join(conn.src.Close(), conn.dst.Close())
// }
// RemoteAddr implements types.StreamConn
func (conn *UDPConn) RemoteAddr() net.Addr {
return conn.src.RemoteAddr()
}
// // RemoteAddr implements types.StreamConn
// func (conn *UDPConn) RemoteAddr() net.Addr {
// return conn.src.RemoteAddr()
// }
type sourceRWCloser struct {
server *net.UDPConn
*net.UDPConn
}
// type sourceRWCloser struct {
// server net.PacketConn
// net.Conn
// }
func (w sourceRWCloser) Write(p []byte) (int, error) {
return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}
// func (w sourceRWCloser) Write(p []byte) (int, error) {
// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr))
// }

View file

@ -0,0 +1,73 @@
package route
import (
"context"
"io"
"net"
"sync"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPListener struct {
ctx context.Context
listener net.PacketConn
connMap UDPConnMap
mu sync.Mutex
}
UDPConnMap = F.Map[string, net.Conn]
)
var NewUDPConnMap = F.NewMap[UDPConnMap]
func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener {
return &UDPListener{
ctx: ctx,
listener: listener,
connMap: NewUDPConnMap(),
}
}
// Addr implements net.Listener.
func (route *UDPListener) Addr() net.Addr {
return route.listener.LocalAddr()
}
func (udpl *UDPListener) Accept() (net.Conn, error) {
in := udpl.listener
buffer := make([]byte, streamBufferSize)
nRead, srcAddr, err := in.ReadFrom(buffer)
if err != nil {
return nil, err
}
if nRead == 0 {
return nil, io.ErrShortBuffer
}
udpl.mu.Lock()
defer udpl.mu.Unlock()
key := srcAddr.String()
conn, ok := udpl.connMap.Load(key)
if !ok {
dialer := &net.Dialer{Timeout: streamDialTimeout}
srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String())
if err != nil {
return nil, err
}
udpl.connMap.Store(key, srcConn)
}
return conn, nil
}
// Close implements net.Listener.
func (route *UDPListener) Close() error {
route.connMap.RangeAllParallel(func(key string, conn net.Conn) {
conn.Close()
})
route.connMap.Clear()
return route.listener.Close()
}

View file

@ -116,7 +116,7 @@ func (s *Server) Start() {
}()
}
s.task.OnComplete("stop server", s.stop)
s.task.OnFinished("stop server", s.stop)
}
func (s *Server) stop() {

View file

@ -39,10 +39,11 @@ type (
// Use Task.Finish to stop all subtasks of the task.
Task interface {
TaskFinisher
fmt.Stringer
// Name returns the name of the task.
Name() string
// Context returns the context associated with the task. This context is
// canceled when Finish is called.
// canceled when Finish of the task is called, or parent task is canceled.
Context() context.Context
// FinishCause returns the reason / error that caused the task to be finished.
FinishCause() error
@ -53,12 +54,16 @@ type (
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
//
// This should not be called after Finish, Wait, or WaitSubTasks is called.
Subtask(usageFmt string, args ...any) Task
// OnComplete calls fn when the task and all subtasks are finished.
Subtask(name string) Task
// OnFinished calls fn when all subtasks are finished.
//
// It cannot be called after Finish or Wait is called.
OnComplete(about string, fn func())
// Wait waits for all subtasks, itself and all OnComplete to finish.
OnFinished(about string, fn func())
// OnCancel calls fn when the task is canceled.
//
// It cannot be called after Finish or Wait is called.
OnCancel(about string, fn func())
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
//
// It must be called only after Finish is called.
Wait()
@ -76,37 +81,46 @@ type (
// The task passed must be a subtask of the caller task.
//
// callerSubtask.Finish must be called when start fails or the object is finished.
Start(callerSubtask Task) E.NestedError
Start(callerSubtask Task) E.Error
}
TaskFinisher interface {
// Finish marks the task as finished by cancelling its context.
// Finish marks the task as finished and cancel its context.
//
// Then call Wait to wait for all subtasks and OnComplete of the task to finish.
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
// of the task to finish.
//
// Note that it will also cancel all subtasks.
Finish(reason string)
Finish(reason any)
}
task struct {
ctx context.Context
cancel context.CancelCauseFunc
parent *task
subtasks *xsync.MapOf[*task, struct{}]
parent *task
subtasks *xsync.MapOf[*task, struct{}]
subTasksWg sync.WaitGroup
name, line string
subTasksWg, onCompleteWg sync.WaitGroup
OnFinishedFuncs []func()
OnFinishedMu sync.Mutex
onFinishedWg sync.WaitGroup
finishOnce sync.Once
}
)
var (
ErrProgramExiting = errors.New("program exiting")
ErrTaskCancelled = errors.New("task cancelled")
ErrTaskCanceled = errors.New("task canceled")
)
// GlobalTask returns a new Task with the given name, derived from the global context.
func GlobalTask(format string, args ...any) Task {
return globalTask.Subtask(format, args...)
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
return globalTask.Subtask(format)
}
// DebugTaskMap returns a map[string]any representation of the global task tree.
@ -155,6 +169,10 @@ func (t *task) Name() string {
return t.name
}
func (t *task) String() string {
return t.name
}
func (t *task) Context() context.Context {
return t.ctx
}
@ -171,43 +189,83 @@ func (t *task) Parent() Task {
return t.parent
}
func (t *task) OnComplete(about string, fn func()) {
t.onCompleteWg.Add(1)
func (t *task) runAllOnFinished(onCompTask Task) {
<-t.ctx.Done()
t.WaitSubTasks()
for _, OnFinishedFunc := range t.OnFinishedFuncs {
OnFinishedFunc()
t.onFinishedWg.Done()
}
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
}
func (t *task) OnFinished(about string, fn func()) {
if t.parent == globalTask {
t.OnCancel(about, fn)
return
}
t.onFinishedWg.Add(1)
t.OnFinishedMu.Lock()
defer t.OnFinishedMu.Unlock()
if t.OnFinishedFuncs == nil {
onCompTask := GlobalTask(t.name + " > OnFinished")
go t.runAllOnFinished(onCompTask)
}
var file string
var line int
if common.IsTrace {
_, file, line, _ = runtime.Caller(1)
}
go func() {
idx := len(t.OnFinishedFuncs)
wrapped := func() {
defer func() {
if err := recover(); err != nil {
logrus.Errorf("panic in task %q\nline %s:%d\n%v", t.name, file, line, err)
logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err)
}
}()
defer t.onCompleteWg.Done()
t.subTasksWg.Wait()
fn()
logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about)
}
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
}
func (t *task) OnCancel(about string, fn func()) {
onCompTask := GlobalTask(t.name + " > OnFinished")
go func() {
<-t.ctx.Done()
fn()
logrus.Tracef("line %s:%d\ntask %q -> %q done", file, line, t.name, about)
t.cancel(nil) // ensure resources are released
onCompTask.Finish("done")
logrus.Tracef("%s > onCancel done: %s", t.name, about)
}()
}
func (t *task) Finish(reason string) {
t.cancel(fmt.Errorf("%w: %s, reason: %s", ErrTaskCancelled, t.name, reason))
t.Wait()
func (t *task) Finish(reason any) {
var format string
switch reason.(type) {
case error:
format = "%w"
case string, fmt.Stringer:
format = "%s"
default:
format = "%v"
}
t.finishOnce.Do(func() {
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
t.Wait()
})
}
func (t *task) Subtask(format string, args ...any) Task {
if len(args) > 0 {
format = fmt.Sprintf(format, args...)
}
func (t *task) Subtask(name string) Task {
ctx, cancel := context.WithCancelCause(t.ctx)
return t.newSubTask(ctx, cancel, format)
return t.newSubTask(ctx, cancel, name)
}
func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task {
parent := t
if common.IsTrace {
name = parent.name + " > " + name
}
subtask := &task{
ctx: ctx,
cancel: cancel,
@ -222,10 +280,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
if ok {
subtask.line = fmt.Sprintf("%s:%d", file, line)
}
logrus.Tracef("line %s\ntask %q started", subtask.line, name)
logrus.Tracef("line %s\n%s started", subtask.line, name)
go func() {
subtask.Wait()
logrus.Tracef("task %q finished", subtask.Name())
logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause())
}()
}
go func() {
@ -237,11 +295,9 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n
}
func (t *task) Wait() {
t.subTasksWg.Wait()
if t != globalTask {
<-t.ctx.Done()
}
t.onCompleteWg.Wait()
<-t.ctx.Done()
t.WaitSubTasks()
t.onFinishedWg.Wait()
}
func (t *task) WaitSubTasks() {
@ -270,9 +326,9 @@ func (t *task) tree(prefix ...string) string {
}
if t.line != "" {
sb.WriteString("line " + t.line + "\n")
}
if len(pre) > 0 {
sb.WriteString(pre + "- ")
if len(pre) > 0 {
sb.WriteString(pre + "- ")
}
}
sb.WriteString(t.Name() + "\n")
t.subtasks.Range(func(subtask *task, _ struct{}) bool {
@ -299,7 +355,8 @@ func (t *task) tree(prefix ...string) string {
// only.
func (t *task) serialize() map[string]any {
m := make(map[string]any)
m["name"] = t.name
parts := strings.Split(t.name, ">")
m["name"] = strings.TrimSpace(parts[len(parts)-1])
if t.line != "" {
m["line"] = t.line
}

View file

@ -40,7 +40,7 @@ func TestTaskCancellation(t *testing.T) {
err := subTask.Context().Err()
ExpectError(t, context.Canceled, err)
cause := context.Cause(subTask.Context())
ExpectError(t, ErrTaskCancelled, cause)
ExpectError(t, ErrTaskCanceled, cause)
case <-time.After(1 * time.Second):
t.Fatal("subTask context was not canceled as expected")
}
@ -74,7 +74,7 @@ func TestOnComplete(t *testing.T) {
task := GlobalTask("test")
var value atomic.Int32
task.OnComplete("set value", func() {
task.OnFinished("set value", func() {
value.Store(1234)
})
task.Finish("done")
@ -90,10 +90,10 @@ func TestGlobalContextWait(t *testing.T) {
subTask1 := rootTask.Subtask("subtask1")
subTask2 := rootTask.Subtask("subtask2")
subTask1.OnComplete("set finished", func() {
subTask1.OnFinished("set finished", func() {
finished1 = true
})
subTask2.OnComplete("set finished", func() {
subTask2.OnFinished("set finished", func() {
finished2 = true
})
@ -117,8 +117,8 @@ func TestGlobalContextWait(t *testing.T) {
ExpectTrue(t, finished1)
ExpectTrue(t, finished2)
ExpectError(t, context.Canceled, rootTask.Context().Err())
ExpectError(t, ErrTaskCancelled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCancelled, context.Cause(subTask2.Context()))
ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
}
func TestTimeoutOnGlobalContextWait(t *testing.T) {

View file

@ -160,7 +160,7 @@ func (m Map[KT, VT]) Has(k KT) bool {
// Returns:
//
// error: if the unmarshaling fails
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError {
func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error {
if m.Size() != 0 {
return E.FailedWhy("unmarshal from yaml", "map is not empty")
}

View file

@ -152,7 +152,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error {
return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src})
}
func LoadJSON[T any](path string, pointer *T) E.NestedError {
func LoadJSON[T any](path string, pointer *T) E.Error {
data, err := E.Check(os.ReadFile(path))
if err.HasError() {
return err
@ -160,7 +160,7 @@ func LoadJSON[T any](path string, pointer *T) E.NestedError {
return E.From(json.Unmarshal(data, pointer))
}
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError {
func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error {
data, err := E.Check(json.Marshal(pointer))
if err.HasError() {
return err

View file

@ -19,11 +19,11 @@ import (
type (
SerializedObject = map[string]any
Converter interface {
ConvertFrom(value any) E.NestedError
ConvertFrom(value any) E.Error
}
)
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error {
var i any
err := yaml.Unmarshal(data, &i)
@ -66,7 +66,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
// Returns:
// - result: The resulting map[string]any representation of the data.
// - error: An error if the data type is unsupported or if there is an error during conversion.
func Serialize(data any) (SerializedObject, E.NestedError) {
func Serialize(data any) (SerializedObject, E.Error) {
result := make(map[string]any)
// Use reflection to inspect the data type
@ -137,7 +137,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.NestedError {
func Deserialize(src SerializedObject, dst any) E.Error {
if src == nil {
return E.Invalid("src", "nil")
}
@ -210,7 +210,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
//
// Returns:
// - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
func Convert(src reflect.Value, dst reflect.Value) E.Error {
srcT := src.Type()
dstT := dst.Type()
@ -277,7 +277,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
return converter.ConvertFrom(src.Interface())
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) {
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
convertible = true
if dst.Kind() == reflect.Ptr {
if dst.IsNil() {
@ -379,7 +379,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N
return true, Convert(reflect.ValueOf(tmp), dst)
}
func DeserializeJSON(j map[string]string, target any) E.NestedError {
func DeserializeJSON(j map[string]string, target any) E.Error {
data, err := E.Check(json.Marshal(j))
if err != nil {
return err

View file

@ -113,7 +113,7 @@ type testType struct {
bar string
}
func (c *testType) ConvertFrom(v any) E.NestedError {
func (c *testType) ConvertFrom(v any) E.Error {
switch v := v.(type) {
case string:
c.bar = v

View file

@ -98,7 +98,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
return got.(T)
}
func Must[T any](v T, err E.NestedError) T {
func Must[T any](v T, err E.Error) T {
if err != nil {
panic(err)
}

View file

@ -21,7 +21,7 @@ type DirWatcher struct {
mu sync.Mutex
eventCh chan Event
errCh chan E.NestedError
errCh chan E.Error
ctx context.Context
}
@ -48,14 +48,14 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher {
w: w,
fwMap: F.NewMapOf[string, *fileWatcher](),
eventCh: make(chan Event),
errCh: make(chan E.NestedError),
errCh: make(chan E.Error),
ctx: ctx,
}
go helper.start()
return helper
}
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) {
func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.Error) {
return h.eventCh, h.errCh
}
@ -71,7 +71,7 @@ func (h *DirWatcher) Add(relPath string) Watcher {
s = &fileWatcher{
relPath: relPath,
eventCh: make(chan Event),
errCh: make(chan E.NestedError),
errCh: make(chan E.Error),
}
go func() {
defer func() {

View file

@ -36,6 +36,14 @@ var (
NewDockerFilter = filters.NewArgs
optionsDefault = DockerListOptions{Filters: NewDockerFilter(
DockerFilterContainer,
DockerFilterStart,
// DockerFilterStop,
DockerFilterDie,
DockerFilterDestroy,
)}
dockerWatcherRetryInterval = 3 * time.Second
)
@ -61,13 +69,13 @@ func NewDockerWatcherWithClient(client D.Client) DockerWatcher {
WithField("host", client.DaemonHost()))}
}
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
return w.EventsWithOptions(ctx, optionsWatchAll)
func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
return w.EventsWithOptions(ctx, optionsDefault)
}
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) {
func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.Error) {
eventCh := make(chan Event)
errCh := make(chan E.NestedError)
errCh := make(chan E.Error)
go func() {
defer close(eventCh)
@ -80,7 +88,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
}()
if !w.client.Connected() {
var err E.NestedError
var err E.Error
attempts := 0
for {
w.client, err = D.ConnectClient(w.host)
@ -141,11 +149,3 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
return eventCh, errCh
}
var optionsWatchAll = DockerListOptions{Filters: NewDockerFilter(
DockerFilterContainer,
DockerFilterStart,
// DockerFilterStop,
DockerFilterDie,
DockerFilterDestroy,
)}

View file

@ -3,20 +3,22 @@ package events
import (
"time"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task"
)
type (
EventQueue struct {
task task.Task
queue []Event
ticker *time.Ticker
onFlush OnFlushFunc
onError OnErrorFunc
task task.Task
queue []Event
ticker *time.Ticker
flushInterval time.Duration
onFlush OnFlushFunc
onError OnErrorFunc
}
OnFlushFunc = func(flushTask task.Task, events []Event)
OnErrorFunc = func(err E.NestedError)
OnErrorFunc = func(err E.Error)
)
const eventQueueCapacity = 10
@ -35,40 +37,45 @@ const eventQueueCapacity = 10
// flushTask.Finish must be called after the flush is done,
// but the onFlush function can return earlier (e.g. run in another goroutine).
//
// If task is cancelled before the flushInterval is reached, the events in queue will be discarded.
// If task is canceled before the flushInterval is reached, the events in queue will be discarded.
func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue {
return &EventQueue{
task: parent.Subtask("event queue"),
queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval),
onFlush: onFlush,
onError: onError,
task: parent.Subtask("event queue"),
queue: make([]Event, 0, eventQueueCapacity),
ticker: time.NewTicker(flushInterval),
flushInterval: flushInterval,
onFlush: onFlush,
onError: onError,
}
}
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.NestedError) {
func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) {
go func() {
defer e.ticker.Stop()
for {
select {
case <-e.task.Context().Done():
e.task.Finish(e.task.FinishCause().Error())
return
case <-e.ticker.C:
if len(e.queue) > 0 {
flushTask := e.task.Subtask("flush events")
queue := e.queue
e.queue = make([]Event, 0, eventQueueCapacity)
go func() {
defer func() {
if err := recover(); err != nil {
e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name()))
}
if !common.IsDebug {
go func() {
defer func() {
if err := recover(); err != nil {
e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name()))
}
}()
e.onFlush(flushTask, queue)
}()
e.onFlush(flushTask, queue)
}()
} else {
go e.onFlush(flushTask, queue)
}
flushTask.Wait()
}
e.ticker.Reset(e.flushInterval)
case event, ok := <-eventCh:
e.queue = append(e.queue, event)
if !ok {

View file

@ -9,9 +9,9 @@ import (
type fileWatcher struct {
relPath string
eventCh chan Event
errCh chan E.NestedError
errCh chan E.Error
}
func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) {
func (fw *fileWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) {
return fw.eventCh, fw.errCh
}

View file

@ -0,0 +1,28 @@
package health
import (
"encoding/json"
"fmt"
"time"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/task"
)
type (
HealthMonitor interface {
task.TaskStarter
task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status
Uptime() time.Duration
Name() string
}
HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
}
)

View file

@ -2,9 +2,7 @@ package health
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
E "github.com/yusing/go-proxy/internal/error"
@ -15,21 +13,6 @@ import (
)
type (
HealthMonitor interface {
task.TaskStarter
task.TaskFinisher
fmt.Stringer
json.Marshaler
Status() Status
Uptime() time.Duration
Name() string
}
HealthChecker interface {
CheckHealth() (healthy bool, detail string, err error)
URL() types.URL
Config() *HealthCheckConfig
UpdateURL(url types.URL)
}
HealthCheckFunc func() (healthy bool, detail string, err error)
monitor struct {
service string
@ -71,7 +54,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance
}
// Start implements task.TaskStarter.
func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
func (mon *monitor) Start(routeSubtask task.Task) E.Error {
mon.service = routeSubtask.Parent().Name()
mon.task = routeSubtask
@ -84,7 +67,6 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
if mon.status.Load() != StatusError {
mon.status.Store(StatusUnknown)
}
mon.task.Finish(mon.task.FinishCause().Error())
}()
if err := mon.checkUpdateHealth(); err != nil {
@ -115,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.NestedError {
}
// Finish implements task.TaskFinisher.
func (mon *monitor) Finish(reason string) {
func (mon *monitor) Finish(reason any) {
mon.task.Finish(reason)
}
@ -169,10 +151,10 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
}).MarshalJSON()
}
func (mon *monitor) checkUpdateHealth() E.NestedError {
func (mon *monitor) checkUpdateHealth() E.Error {
healthy, detail, err := mon.checkHealth()
if err != nil {
defer mon.task.Finish(err.Error())
defer mon.task.Finish(err)
mon.status.Store(StatusError)
if !errors.Is(err, context.Canceled) {
return E.Failure("check health").With(err)

View file

@ -10,5 +10,5 @@ import (
type Event = events.Event
type Watcher interface {
Events(ctx context.Context) (<-chan Event, <-chan E.NestedError)
Events(ctx context.Context) (<-chan Event, <-chan E.Error)
}