migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher

This commit is contained in:
yusing 2024-10-29 11:34:58 +08:00
parent cfa74d69ae
commit e5bbb18414
137 changed files with 2640 additions and 2348 deletions

View file

@ -65,3 +65,6 @@ debug-list-containers:
ci-test: ci-test:
mkdir -p /tmp/artifacts mkdir -p /tmp/artifacts
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
cloc:
cloc --not-match-f '_test.go$$' cmd internal pkg

View file

@ -2,7 +2,6 @@ package main
import ( import (
"encoding/json" "encoding/json"
"io"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -10,15 +9,14 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/notif"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -33,44 +31,26 @@ func main() {
return return
} }
l := logrus.WithField("module", "main")
timeFmt := "01-02 15:04:05"
fullTS := true
if common.IsTrace {
logrus.SetLevel(logrus.TraceLevel)
timeFmt = "04:05"
fullTS = false
} else if common.IsDebug {
logrus.SetLevel(logrus.DebugLevel)
}
if args.Command != common.CommandStart {
logrus.SetOutput(io.Discard)
} else {
logrus.SetFormatter(&logrus.TextFormatter{
DisableSorting: true,
FullTimestamp: fullTS,
ForceColors: true,
TimestampFormat: timeFmt,
})
logrus.Infof("go-proxy version %s", pkg.GetVersion())
logrus.AddHook(notif.GetDispatcher())
}
if args.Command == common.CommandReload { if args.Command == common.CommandReload {
if err := query.ReloadServer(); err != nil { if err := query.ReloadServer(); err != nil {
log.Fatal(err) E.LogFatal("server reload error", err)
} }
log.Print("ok") logging.Info().Msg("ok")
return return
} }
// exit if only validate config if args.Command == common.CommandStart {
logging.Info().Msgf("go-proxy version %s", pkg.GetVersion())
logging.Trace().Msg("trace enabled")
// logging.AddHook(notif.GetDispatcher())
} else {
logging.DiscardLogger()
}
if args.Command == common.CommandValidate { if args.Command == common.CommandValidate {
data, err := os.ReadFile(common.ConfigPath) data, err := os.ReadFile(common.ConfigPath)
if err == nil { if err == nil {
err = config.Validate(data).Error() err = config.Validate(data)
} }
if err != nil { if err != nil {
log.Fatal("config error: ", err) log.Fatal("config error: ", err)
@ -88,7 +68,7 @@ func main() {
var cfg *config.Config var cfg *config.Config
var err E.Error var err E.Error
if cfg, err = config.Load(); err != nil { if cfg, err = config.Load(); err != nil {
logrus.Warn(err) E.LogWarn("errors in config", err)
} }
switch args.Command { switch args.Command {
@ -145,10 +125,10 @@ func main() {
autocert := config.GetAutoCertProvider() autocert := config.GetAutoCertProvider()
if autocert != nil { if autocert != nil {
if err := autocert.Setup(); err != nil { if err := autocert.Setup(); err != nil {
l.Fatal(err) E.LogFatal("autocert setup error", err)
} }
} else { } else {
l.Info("autocert not configured") logging.Info().Msg("autocert not configured")
} }
proxyServer := server.InitProxyServer(server.Options{ proxyServer := server.InitProxyServer(server.Options{
@ -174,7 +154,7 @@ func main() {
<-sig <-sig
// grafully shutdown // grafully shutdown
logrus.Info("shutting down") logging.Info().Msg("shutting down")
task.CancelGlobalContext() task.CancelGlobalContext()
task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown)) task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown))
} }
@ -182,15 +162,15 @@ func main() {
func prepareDirectory(dir string) { func prepareDirectory(dir string) {
if _, err := os.Stat(dir); os.IsNotExist(err) { if _, err := os.Stat(dir); os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0o755); err != nil { if err = os.MkdirAll(dir, 0o755); err != nil {
logrus.Fatalf("failed to create directory %s: %v", dir, err) logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
} }
} }
} }
func printJSON(obj any) { func printJSON(obj any) {
j, err := E.Check(json.MarshalIndent(obj, "", " ")) j, err := json.MarshalIndent(obj, "", " ")
if err != nil { if err != nil {
logrus.Fatal(err) logging.Fatal().Err(err).Send()
} }
rawLogger := log.New(os.Stdout, "", 0) rawLogger := log.New(os.Stdout, "", 0)
rawLogger.Printf("%s", j) // raw output for convenience using "jq" rawLogger.Printf("%s", j) // raw output for convenience using "jq"

7
go.mod
View file

@ -10,8 +10,8 @@ require (
github.com/go-acme/lego/v4 v4.19.2 github.com/go-acme/lego/v4 v4.19.2
github.com/gotify/server/v2 v2.5.0 github.com/gotify/server/v2 v2.5.0
github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/puzpuzpuz/xsync/v3 v3.4.0
github.com/rs/zerolog v1.33.0
github.com/santhosh-tekuri/jsonschema v1.2.4 github.com/santhosh-tekuri/jsonschema v1.2.4
github.com/sirupsen/logrus v1.9.3
golang.org/x/net v0.30.0 golang.org/x/net v0.30.0
golang.org/x/text v0.19.0 golang.org/x/text v0.19.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
@ -20,7 +20,7 @@ require (
require ( require (
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudflare/cloudflare-go v0.107.0 // indirect github.com/cloudflare/cloudflare-go v0.108.0 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
@ -33,6 +33,8 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/miekg/dns v1.1.62 // indirect github.com/miekg/dns v1.1.62 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect github.com/moby/term v0.5.0 // indirect
@ -42,6 +44,7 @@ require (
github.com/ovh/go-ovh v1.6.0 // indirect github.com/ovh/go-ovh v1.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect

18
go.sum
View file

@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc= github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0=
github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc= github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -40,6 +41,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@ -61,6 +63,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g=
github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
@ -88,6 +96,9 @@ github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPK
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4= github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
@ -140,6 +151,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -6,7 +6,6 @@ import (
"net/http" "net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1" v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
. "github.com/yusing/go-proxy/internal/api/v1/utils" . "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
@ -38,7 +37,6 @@ func NewHandler() http.Handler {
mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent)
mux.HandleFunc("GET", "/v1/stats", v1.Stats) mux.HandleFunc("GET", "/v1/stats", v1.Stats)
mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS)
mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc())
return mux return mux
} }
@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
host, _, _ := net.SplitHostPort(r.RemoteAddr) host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
Logger.Warnf("blocked API request from %s", host) LogWarn(r).Msgf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden) http.Error(w, "forbidden", http.StatusForbidden)
return return
} }

View file

@ -1,31 +0,0 @@
package errorpage
import (
"net/http"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
)
func GetHandleFunc() http.HandlerFunc {
setup()
return serveHTTP
}
func serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
content, ok := fileContentMap.Load(r.URL.Path)
if !ok {
http.Error(w, "404 not found", http.StatusNotFound)
return
}
if _, err := w.Write(content); err != nil {
HandleErr(w, r, err)
}
}

View file

@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) {
return return
} }
var validateErr E.Error var valErr E.Error
if filename == common.ConfigFileName { if filename == common.ConfigFileName {
validateErr = config.Validate(content) valErr = config.Validate(content)
} else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) {
validateErr = provider.Validate(content) valErr = provider.Validate(content)
} }
// no validation for include files
if validateErr != nil { if valErr != nil {
U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest) U.RespondJSON(w, r, valErr, http.StatusBadRequest)
return return
} }

View file

@ -20,16 +20,13 @@ func ReloadServer() E.Error {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode) failure := E.Errorf("server reload status %v", resp.StatusCode)
b, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return failure.Extraf("unable to read response body: %s", err) return failure.With(err)
} }
reloadErr, ok := E.FromJSON(b) reloadErr := string(body)
if ok { return failure.Withf(reloadErr)
return E.Join("reload success, but server returned error", reloadErr)
}
return failure.Extraf("unable to read response body")
} }
return nil return nil
} }
@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode) outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode)
return return
} }
var res T var res T

View file

@ -9,8 +9,8 @@ import (
func Reload(w http.ResponseWriter, r *http.Request) { func Reload(w http.ResponseWriter, r *http.Request) {
if err := config.Reload(); err != nil { if err := config.Reload(); err != nil {
U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) U.HandleErr(w, r, err)
} else { return
w.WriteHeader(http.StatusOK)
} }
U.WriteBody(w, []byte("OK"))
} }

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/server"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
func Stats(w http.ResponseWriter, r *http.Request) { func Stats(w http.ResponseWriter, r *http.Request) {
@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses)) originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses))
if len(originPats) == 0 { if len(originPats) == 0 {
U.Logger.Warnf("no match domains configured, accepting websocket request from all origins") U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
originPats = []string{"*"} originPats = []string{"*"}
} else { } else {
for i, domain := range config.Value().MatchDomains { for i, domain := range config.Value().MatchDomains {
@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
OriginPatterns: originPats, OriginPatterns: originPats,
}) })
if err != nil { if err != nil {
U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err) U.LogError(r).Err(err).Msg("failed to upgrade websocket")
return return
} }
/* trunk-ignore(golangci-lint/errcheck) */ /* trunk-ignore(golangci-lint/errcheck) */
@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
for range ticker.C { for range ticker.C {
stats := getStats() stats := getStats()
if err := wsjson.Write(ctx, conn, stats); err != nil { if err := wsjson.Write(ctx, conn, stats); err != nil {
U.Logger.Errorf("/stats/ws failed to write JSON: %s", err) U.LogError(r).Msg("failed to write JSON")
return return
} }
} }
@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) {
func getStats() map[string]any { func getStats() map[string]any {
return map[string]any{ return map[string]any{
"proxies": config.Statistics(), "proxies": config.Statistics(),
"uptime": utils.FormatDuration(server.GetProxyServer().Uptime()), "uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()),
} }
} }

View file

@ -1,37 +1,31 @@
package utils package utils
import ( import (
"errors"
"fmt"
"net/http" "net/http"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
var Logger = logrus.WithField("module", "api")
func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) {
if origErr == nil { if origErr == nil {
return return
} }
err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL) LogError(r).Msg(origErr.Error())
Logger.Error(err)
if len(code) > 0 { if len(code) > 0 {
http.Error(w, err.String(), code[0]) http.Error(w, origErr.Error(), code[0])
return return
} }
http.Error(w, err.String(), http.StatusInternalServerError) http.Error(w, origErr.Error(), http.StatusInternalServerError)
} }
func ErrMissingKey(k string) error { func ErrMissingKey(k string) error {
return errors.New("missing key '" + k + "' in query or request body") return E.New("missing key '" + k + "' in query or request body")
} }
func ErrInvalidKey(k string) error { func ErrInvalidKey(k string) error {
return errors.New("invalid key '" + k + "' in query or request body") return E.New("invalid key '" + k + "' in query or request body")
} }
func ErrNotFound(k, v string) error { func ErrNotFound(k, v string) error {
return fmt.Errorf("key %q with value %q not found", k, v) return E.Errorf("key %q with value %q not found", k, v)
} }

View file

@ -9,12 +9,11 @@ import (
) )
var ( var (
HTTPClient = &http.Client{ httpClient = &http.Client{
Timeout: common.ConnectionTimeout, Timeout: common.ConnectionTimeout,
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DisableKeepAlives: true, DisableKeepAlives: true,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: false,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: common.DialTimeout, Timeout: common.DialTimeout,
KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives
@ -23,7 +22,7 @@ var (
}, },
} }
Get = HTTPClient.Get Get = httpClient.Get
Post = HTTPClient.Post Post = httpClient.Post
Head = HTTPClient.Head Head = httpClient.Head
) )

View file

@ -0,0 +1,18 @@
package utils
import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).Str("module", "api").
Str("method", r.Method).
Str("path", r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) }
func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) }

View file

@ -3,6 +3,8 @@ package utils
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/logging"
) )
func WriteBody(w http.ResponseWriter, body []byte) { func WriteBody(w http.ResponseWriter, body []byte) {
@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) {
} }
} }
func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool { func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) {
if len(code) > 0 { if len(code) > 0 {
w.WriteHeader(code[0]) w.WriteHeader(code[0])
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
j, err := json.MarshalIndent(data, "", " ") var j []byte
if err != nil { var err error
HandleErr(w, r, err)
return false switch data := data.(type) {
case string:
j = []byte(`"` + data + `"`)
case []byte:
j = data
default:
j, err = json.MarshalIndent(data, "", " ")
if err != nil {
logging.Panic().Err(err).Msg("failed to marshal json")
return false
}
} }
_, err = w.Write(j) _, err = w.Write(j)
if err != nil { if err != nil {

View file

@ -8,12 +8,21 @@ import (
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/config/types"
) )
type Config types.AutoCertConfig type Config types.AutoCertConfig
var (
ErrMissingDomain = E.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'")
ErrUnknownProvider = E.New("unknown provider")
)
func NewConfig(cfg *types.AutoCertConfig) *Config { func NewConfig(cfg *types.AutoCertConfig) *Config {
if cfg.CertPath == "" { if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault cfg.CertPath = CertFileDefault
@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config {
return (*Config)(cfg) return (*Config)(cfg)
} }
func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { func (cfg *Config) GetProvider() (*Provider, E.Error) {
b := E.NewBuilder("unable to initialize autocert") b := E.NewBuilder("autocert errors")
defer b.To(&res)
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal {
if len(cfg.Domains) == 0 { if len(cfg.Domains) == 0 {
b.Addf("%s", "no domains specified") b.Add(ErrMissingDomain)
} }
if cfg.Provider == "" { if cfg.Provider == "" {
b.Addf("%s", "no provider specified") b.Add(ErrMissingProvider)
} }
if cfg.Email == "" { if cfg.Email == "" {
b.Addf("%s", "no email specified") b.Add(ErrMissingEmail)
} }
// check if provider is implemented // check if provider is implemented
_, ok := providersGenMap[cfg.Provider] _, ok := providersGenMap[cfg.Provider]
if !ok { if !ok {
b.Addf("unknown provider: %q", cfg.Provider) b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap))))
} }
} }
if b.HasError() { if b.HasError() {
return return nil, b.Error()
} }
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err.HasError() { if err != nil {
b.Add(E.FailWith("generate private key", err)) b.Addf("generate private key: %w", err)
return return nil, b.Error()
} }
user := &User{ user := &User{
@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.Error) {
legoCfg := lego.NewConfig(user) legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048 legoCfg.Certificate.KeyType = certcrypto.RSA2048
provider = &Provider{ return &Provider{
cfg: cfg, cfg: cfg,
user: user, user: user,
legoCfg: legoCfg, legoCfg: legoCfg,
} }, nil
return
} }

View file

@ -7,7 +7,6 @@ import (
"github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/duckdns"
"github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -36,5 +35,3 @@ var providersGenMap = map[string]ProviderGenerator{
var ( var (
ErrGetCertFailure = errors.New("get certificate failed") ErrGetCertFailure = errors.New("get certificate failed")
) )
var logger = logrus.WithField("module", "autocert")

View file

@ -0,0 +1,5 @@
package autocert
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "autocert").Logger()

View file

@ -17,6 +17,7 @@ import (
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
@ -57,25 +58,20 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries return p.certExpiries
} }
func (p *Provider) ObtainCert() (res E.Error) { func (p *Provider) ObtainCert() E.Error {
b := E.NewBuilder("failed to obtain certificate")
defer b.To(&res)
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
return nil return nil
} }
if p.client == nil { if p.client == nil {
if err := p.initClient(); err.HasError() { if err := p.initClient(); err != nil {
b.Add(E.FailWith("init autocert client", err)) return err
return
} }
} }
if p.user.Registration == nil { if p.user.Registration == nil {
if err := p.registerACME(); err.HasError() { if err := p.registerACME(); err != nil {
b.Add(E.FailWith("register ACME", err)) return E.From(err)
return
} }
} }
@ -84,27 +80,23 @@ func (p *Provider) ObtainCert() (res E.Error) {
Domains: p.cfg.Domains, Domains: p.cfg.Domains,
Bundle: true, Bundle: true,
} }
cert, err := E.Check(client.Certificate.Obtain(req)) cert, err := client.Certificate.Obtain(req)
if err.HasError() { if err != nil {
b.Add(err) return E.From(err)
return
} }
if err = p.saveCert(cert); err.HasError() { if err = p.saveCert(cert); err != nil {
b.Add(E.FailWith("save certificate", err)) return E.From(err)
return
} }
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
if err.HasError() { if err != nil {
b.Add(E.FailWith("parse obtained certificate", err)) return E.From(err)
return
} }
expiries, err := getCertExpiries(&tlsCert) expiries, err := getCertExpiries(&tlsCert)
if err.HasError() { if err != nil {
b.Add(E.FailWith("get certificate expiry", err)) return E.From(err)
return
} }
p.tlsCert = &tlsCert p.tlsCert = &tlsCert
p.certExpiries = expiries p.certExpiries = expiries
@ -113,21 +105,22 @@ func (p *Provider) ObtainCert() (res E.Error) {
} }
func (p *Provider) LoadCert() E.Error { func (p *Provider) LoadCert() E.Error {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err.HasError() { if err != nil {
return err return E.Errorf("load SSL certificate: %w", err)
} }
expiries, err := getCertExpiries(&cert) expiries, err := getCertExpiries(&cert)
if err.HasError() { if err != nil {
return err return E.Errorf("parse SSL certificate: %w", err)
} }
p.tlsCert = &cert p.tlsCert = &cert
p.certExpiries = expiries p.certExpiries = expiries
logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn()))) logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded() return p.renewIfNeeded()
} }
// ShouldRenewOn returns the time at which the certificate should be renewed.
func (p *Provider) ShouldRenewOn() time.Time { func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries { for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before return expiry.AddDate(0, -1, 0) // 1 month before
@ -150,8 +143,8 @@ func (p *Provider) ScheduleRenewal() {
case <-task.Context().Done(): case <-task.Context().Done():
return return
case <-ticker.C: // check every 5 seconds case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() { if err := p.renewIfNeeded(); err != nil {
logger.Warn(err) E.LogWarn("cert renew failed", err, &logger)
} }
} }
} }
@ -159,31 +152,32 @@ func (p *Provider) ScheduleRenewal() {
} }
func (p *Provider) initClient() E.Error { func (p *Provider) initClient() E.Error {
legoClient, err := E.Check(lego.NewClient(p.legoCfg)) legoClient, err := lego.NewClient(p.legoCfg)
if err.HasError() { if err != nil {
return E.FailWith("create lego client", err) return E.From(err)
} }
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) generator := providersGenMap[p.cfg.Provider]
if err.HasError() { legoProvider, pErr := generator(p.cfg.Options)
return E.FailWith("create lego provider", err) if pErr != nil {
return pErr
} }
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) err = legoClient.Challenge.SetDNS01Provider(legoProvider)
if err.HasError() { if err != nil {
return E.FailWith("set challenge provider", err) return E.From(err)
} }
p.client = legoClient p.client = legoClient
return nil return nil
} }
func (p *Provider) registerACME() E.Error { func (p *Provider) registerACME() error {
if p.user.Registration != nil { if p.user.Registration != nil {
return nil return nil
} }
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err.HasError() { if err != nil {
return err return err
} }
p.user.Registration = reg p.user.Registration = reg
@ -191,26 +185,27 @@ func (p *Provider) registerACME() E.Error {
return nil return nil
} }
func (p *Provider) saveCert(cert *certificate.Resource) E.Error { func (p *Provider) saveCert(cert *certificate.Resource) error {
/* This should have been done in setup /* This should have been done in setup
but double check is always a good choice.*/ but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath)) _, err := os.Stat(path.Dir(p.cfg.CertPath))
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
return E.FailWith("create cert directory", err) return err
} }
} else { } else {
return E.FailWith("stat cert directory", err) return err
} }
} }
err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil { if err != nil {
return E.FailWith("write key file", err) return err
} }
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil { if err != nil {
return E.FailWith("write cert file", err) return err
} }
return nil return nil
} }
@ -232,7 +227,7 @@ func (p *Provider) certState() CertState {
sort.Strings(certDomains) sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) { if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch return CertStateMismatch
} }
@ -246,25 +241,25 @@ func (p *Provider) renewIfNeeded() E.Error {
switch p.certState() { switch p.certState() {
case CertStateExpired: case CertStateExpired:
logger.Info("certs expired, renewing") logger.Info().Msg("certs expired, renewing")
case CertStateMismatch: case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing") logger.Info().Msg("cert domains mismatch with config, renewing")
default: default:
return nil return nil
} }
if err := p.ObtainCert(); err.HasError() { if err := p.ObtainCert(); err != nil {
return E.FailWith("renew certificate", err) return err
} }
return nil return nil
} }
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) { func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
r := make(CertExpiries, len(cert.Certificate)) r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate { for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert)) x509Cert, err := x509.ParseCertificate(cert)
if err.HasError() { if err != nil {
return nil, E.FailWith("parse certificate", err) return nil, err
} }
if x509Cert.IsCA { if x509Cert.IsCA {
continue continue
@ -284,13 +279,10 @@ func providerGenerator[CT any, PT challenge.Provider](
return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) {
cfg := defaultCfg() cfg := defaultCfg()
err := U.Deserialize(opt, cfg) err := U.Deserialize(opt, cfg)
if err.HasError() { if err != nil {
return nil, err return nil, err
} }
p, err := E.Check(newProvider(cfg)) p, pErr := newProvider(cfg)
if err.HasError() { return p, E.From(pErr)
return nil, err
}
return p, nil
} }
} }

View file

@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any) opt := make(map[string]any)
ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt)) ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt))
ExpectNoError(t, U.Deserialize(opt, cfg).Error()) ExpectNoError(t, U.Deserialize(opt, cfg))
ExpectDeepEqual(t, cfg, cfgExpected) ExpectDeepEqual(t, cfg, cfgExpected)
} }

View file

@ -11,7 +11,7 @@ func (p *Provider) Setup() (err E.Error) {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
return err return err
} }
logger.Debug("obtaining cert due to error loading cert") logger.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil { if err = p.ObtainCert(); err != nil {
return err return err
} }
@ -20,7 +20,7 @@ func (p *Provider) Setup() (err E.Error) {
p.ScheduleRenewal() p.ScheduleRenewal()
for _, expiry := range p.GetExpiries() { for _, expiry := range p.GetExpiries() {
logger.Infof("certificate expire on %s", expiry) logger.Info().Msg("certificate expire on " + expiry.String())
break break
} }

View file

@ -3,8 +3,7 @@ package common
import ( import (
"flag" "flag"
"fmt" "fmt"
"log"
"github.com/sirupsen/logrus"
) )
type Args struct { type Args struct {
@ -44,7 +43,7 @@ func GetArgs() Args {
flag.Parse() flag.Parse()
args.Command = flag.Arg(0) args.Command = flag.Arg(0)
if err := validateArg(args.Command); err != nil { if err := validateArg(args.Command); err != nil {
logrus.Fatal(err) log.Fatalf("invalid command: %s", err)
} }
return args return args
} }
@ -55,5 +54,5 @@ func validateArg(arg string) error {
return nil return nil
} }
} }
return fmt.Errorf("invalid command: %s", arg) return fmt.Errorf("invalid command %q", arg)
} }

View file

@ -7,8 +7,6 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"github.com/sirupsen/logrus"
) )
var ( var (
@ -40,7 +38,7 @@ func GetEnvBool(key string, defaultValue bool) bool {
} }
b, err := strconv.ParseBool(value) b, err := strconv.ParseBool(value)
if err != nil { if err != nil {
log.Fatalf("Invalid boolean value: %s", value) log.Fatalf("env %s: invalid boolean value: %s", key, value)
} }
return b return b
} }
@ -57,7 +55,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
addr = GetEnv(key, defaultValue) addr = GetEnv(key, defaultValue)
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
logrus.Fatalf("Invalid address: %s", addr) log.Fatalf("env %s: invalid address: %s", key, addr)
} }
if host == "" { if host == "" {
host = "localhost" host = "localhost"

View file

@ -2,14 +2,15 @@ package config
import ( import (
"os" "os"
"strconv"
"sync" "sync"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider" proxy "github.com/yusing/go-proxy/internal/route/provider"
@ -31,7 +32,7 @@ type Config struct {
var ( var (
instance *Config instance *Config
cfgWatcher watcher.Watcher cfgWatcher watcher.Watcher
logger = logrus.WithField("module", "config") logger = logging.With().Str("module", "config").Logger()
reloadMu sync.Mutex reloadMu sync.Mutex
) )
@ -80,7 +81,7 @@ func WatchChanges() {
configEventFlushInterval, configEventFlushInterval,
OnConfigChange, OnConfigChange,
func(err E.Error) { func(err E.Error) {
logger.Error(err) E.LogError("config reload error", err, &logger)
}, },
) )
eventQueue.Start(cfgWatcher.Events(task.Context())) eventQueue.Start(cfgWatcher.Events(task.Context()))
@ -93,15 +94,16 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) {
// just reload once and check the last event // just reload once and check the last event
switch ev[len(ev)-1].Action { switch ev[len(ev)-1].Action {
case events.ActionFileRenamed: case events.ActionFileRenamed:
logger.Warn(cfgRenameWarn) logger.Warn().Msg(cfgRenameWarn)
return return
case events.ActionFileDeleted: case events.ActionFileDeleted:
logger.Warn(cfgDeleteWarn) logger.Warn().Msg(cfgDeleteWarn)
return return
} }
if err := Reload(); err != nil { if err := Reload(); err != nil {
logger.Error(err) // recovered in event queue
panic(err)
} }
} }
@ -138,63 +140,57 @@ func (cfg *Config) Task() task.Task {
} }
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
b := E.NewBuilder("errors starting providers") errs := cfg.providers.CollectErrorsParallel(
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { func(_ string, p *proxy.Provider) error {
b.Add(p.Start(cfg.task.Subtask(p.String()))) subtask := cfg.task.Subtask(p.String())
}) return p.Start(subtask)
})
if b.HasError() { if err := E.Join(errs...); err != nil {
logger.Error(b.Build()) E.LogError("route provider errors", err, &logger)
} }
} }
func (cfg *Config) load() (res E.Error) { func (cfg *Config) load() E.Error {
errs := E.NewBuilder("errors loading config") const errMsg = "config load error"
defer errs.To(&res)
logger.Debug("loading config") data, err := os.ReadFile(common.ConfigPath)
defer logger.Debug("loaded config")
data, err := E.Check(os.ReadFile(common.ConfigPath))
if err != nil { if err != nil {
errs.Add(E.FailWith("read config", err)) E.LogFatal(errMsg, err, &logger)
logrus.Fatal(errs.Build())
} }
if !common.NoSchemaValidation { if !common.NoSchemaValidation {
if err = Validate(data); err != nil { if err := Validate(data); err != nil {
errs.Add(E.FailWith("schema validation", err)) E.LogFatal(errMsg, err, &logger)
logrus.Fatal(errs.Build())
} }
} }
model := types.DefaultConfig() model := types.DefaultConfig()
if err := E.From(yaml.Unmarshal(data, model)); err != nil { if err := E.From(yaml.Unmarshal(data, model)); err != nil {
errs.Add(E.FailWith("parse config", err)) E.LogFatal(errMsg, err, &logger)
logrus.Fatal(errs.Build())
} }
// errors are non fatal below // errors are non fatal below
errs := E.NewBuilder(errMsg)
errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initNotification(model.Providers.Notification))
errs.Add(cfg.initAutoCert(&model.AutoCert)) errs.Add(cfg.initAutoCert(&model.AutoCert))
errs.Add(cfg.loadRouteProviders(&model.Providers)) errs.Add(cfg.loadRouteProviders(&model.Providers))
cfg.value = model cfg.value = model
route.SetFindMuxDomains(model.MatchDomains) route.SetFindMuxDomains(model.MatchDomains)
return return errs.Error()
} }
func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) { func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) {
if len(notifCfgMap) == 0 { if len(notifCfgMap) == 0 {
return return
} }
errs := E.NewBuilder("errors initializing notification providers") errs := E.NewBuilder("notification providers load errors")
for name, notifCfg := range notifCfgMap { for name, notifCfg := range notifCfgMap {
_, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg) _, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg)
errs.Add(err) errs.Add(err)
} }
return errs.Build() return errs.Error()
} }
func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) {
@ -203,40 +199,45 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error)
} }
cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider()
if err != nil {
err = E.FailWith("autocert provider", err)
}
return return
} }
func (cfg *Config) loadRouteProviders(providers *types.Providers) (outErr E.Error) { func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error {
subtask := cfg.task.Subtask("load route providers") subtask := cfg.task.Subtask("load route providers")
defer subtask.Finish("done") defer subtask.Finish("done")
errs := E.NewBuilder("errors loading route providers") errs := E.NewBuilder("route provider errors")
results := E.NewBuilder("loaded providers") results := E.NewBuilder("loaded route providers")
defer errs.To(&outErr)
lenLongestName := 0
for _, filename := range providers.Files { for _, filename := range providers.Files {
p, err := proxy.NewFileProvider(filename) p, err := proxy.NewFileProvider(filename)
if err != nil { if err != nil {
errs.Add(err) errs.Add(E.PrependSubject(filename, err))
continue continue
} }
cfg.providers.Store(p.GetName(), p) cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(filename)) if len(p.GetName()) > lenLongestName {
results.Addf("%d routes from %s", p.NumRoutes(), p.String()) lenLongestName = len(p.GetName())
}
} }
for name, dockerHost := range providers.Docker { for name, dockerHost := range providers.Docker {
p, err := proxy.NewDockerProvider(name, dockerHost) p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil { if err != nil {
errs.Add(err.Subjectf("%s (%s)", name, dockerHost)) errs.Add(E.PrependSubject(name, err))
continue continue
} }
cfg.providers.Store(p.GetName(), p) cfg.providers.Store(p.GetName(), p)
errs.Add(p.LoadRoutes().Subject(p.GetName())) if len(p.GetName()) > lenLongestName {
results.Addf("%d routes from %s", p.NumRoutes(), p.String()) lenLongestName = len(p.GetName())
}
} }
logger.Info(results.Build()) cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
return if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String()))
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes())
})
logger.Info().Msg(results.String())
return errs.Error()
} }

View file

@ -9,8 +9,8 @@ import (
"github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
proxy "github.com/yusing/go-proxy/internal/route/provider" proxy "github.com/yusing/go-proxy/internal/route/provider"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
func DumpEntries() map[string]*entry.RawEntry { func DumpEntries() map[string]*entry.RawEntry {
@ -61,7 +61,7 @@ func HomepageConfig() homepage.Config {
} }
if item.Name == "" { if item.Name == "" {
item.Name = U.Title( item.Name = strutils.Title(
strings.ReplaceAll( strings.ReplaceAll(
strings.ReplaceAll(alias, "-", " "), strings.ReplaceAll(alias, "-", " "),
"_", " ", "_", " ",

View file

@ -1,14 +1,15 @@
package docker package docker
import ( import (
"errors"
"net/http" "net/http"
"sync" "sync"
"github.com/docker/cli/cli/connhelper" "github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
@ -22,7 +23,7 @@ type (
key string key string
refCount *U.RefCount refCount *U.RefCount
l logrus.FieldLogger l zerolog.Logger
} }
) )
@ -70,7 +71,7 @@ func (c *SharedClient) Close() error {
// Returns: // Returns:
// - Client: the Docker client connection. // - Client: the Docker client connection.
// - error: an error if the connection failed. // - error: an error if the connection failed.
func ConnectClient(host string) (Client, E.Error) { func ConnectClient(host string) (Client, error) {
clientMapMu.Lock() clientMapMu.Lock()
defer clientMapMu.Unlock() defer clientMapMu.Unlock()
@ -85,13 +86,13 @@ func ConnectClient(host string) (Client, E.Error) {
switch host { switch host {
case "": case "":
return nil, E.Invalid("docker host", "empty") return nil, errors.New("empty docker host")
case common.DockerHostFromEnv: case common.DockerHostFromEnv:
opt = clientOptEnvHost opt = clientOptEnvHost
default: default:
helper, err := E.Check(connhelper.GetConnectionHelper(host)) helper, err := connhelper.GetConnectionHelper(host)
if err.HasError() { if err != nil {
return nil, E.UnexpectedError(err.Error()) logging.Panic().Err(err).Msg("failed to get connection helper")
} }
if helper != nil { if helper != nil {
httpClient := &http.Client{ httpClient := &http.Client{
@ -113,9 +114,9 @@ func ConnectClient(host string) (Client, E.Error) {
} }
} }
client, err := E.Check(client.NewClientWithOpts(opt...)) client, err := client.NewClientWithOpts(opt...)
if err.HasError() { if err != nil {
return nil, err return nil, err
} }
@ -123,9 +124,9 @@ func ConnectClient(host string) (Client, E.Error) {
Client: client, Client: client,
key: host, key: host,
refCount: U.NewRefCounter(), refCount: U.NewRefCounter(),
l: logger.WithField("docker_client", client.DaemonHost()), l: logger.With().Str("address", client.DaemonHost()).Logger(),
} }
c.l.Debugf("client connected") c.l.Trace().Msg("client connected")
clientMap.Store(host, c) clientMap.Store(host, c)
@ -135,7 +136,7 @@ func ConnectClient(host string) (Client, E.Error) {
if c.Connected() { if c.Connected() {
c.Client.Close() c.Client.Close()
c.l.Debugf("client closed") c.l.Trace().Msg("client closed")
} }
}() }()
return c, nil return c, nil

View file

@ -6,8 +6,8 @@ import (
"strings" "strings"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
@ -59,7 +59,7 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) {
NetworkMode: c.HostConfig.NetworkMode, NetworkMode: c.HostConfig.NetworkMode,
Aliases: helper.getAliases(), Aliases: helper.getAliases(),
IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)), IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)),
IsExplicit: isExplicit, IsExplicit: isExplicit,
IsDatabase: helper.isDatabase(), IsDatabase: helper.isDatabase(),
IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout),
@ -120,7 +120,7 @@ func (c *Container) setPublicIP() {
} }
url, err := url.Parse(c.DockerHost) url, err := url.Parse(c.DockerHost)
if err != nil { if err != nil {
logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err) logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
c.PublicIP = "127.0.0.1" c.PublicIP = "127.0.0.1"
return return
} }

View file

@ -4,7 +4,7 @@ import (
"strings" "strings"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type containerHelper struct { type containerHelper struct {
@ -23,7 +23,7 @@ func (c containerHelper) getDeleteLabel(label string) string {
func (c containerHelper) getAliases() []string { func (c containerHelper) getAliases() []string {
if l := c.getDeleteLabel(LabelAliases); l != "" { if l := c.getDeleteLabel(LabelAliases); l != "" {
return U.CommaSeperatedList(l) return strutils.CommaSeperatedList(l)
} }
return []string{c.getName()} return []string{c.getName()}
} }
@ -44,7 +44,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
if v.PublicPort == 0 { if v.PublicPort == 0 {
continue continue
} }
res[U.PortString(v.PublicPort)] = v res[strutils.PortString(v.PublicPort)] = v
} }
return res return res
} }
@ -52,7 +52,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping {
func (c containerHelper) getPrivatePortMapping() PortMapping { func (c containerHelper) getPrivatePortMapping() PortMapping {
res := make(PortMapping) res := make(PortMapping)
for _, v := range c.Ports { for _, v := range c.Ports {
res[U.PortString(v.PrivatePort)] = v res[strutils.PortString(v.PrivatePort)] = v
} }
return res return res
} }

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"errors"
"time" "time"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
@ -30,7 +31,7 @@ const (
StopMethodKill StopMethod = "kill" StopMethodKill StopMethod = "kill"
) )
func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { func ValidateConfig(cont *docker.Container) (*Config, E.Error) {
if cont == nil { if cont == nil {
return nil, nil return nil, nil
} }
@ -44,26 +45,16 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
}, nil }, nil
} }
b := E.NewBuilder("invalid idlewatcher config") errs := E.NewBuilder("invalid idlewatcher config")
defer b.To(&res)
idleTimeout, err := validateDurationPostitive(cont.IdleTimeout) idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout)
b.Add(err.Subjectf("%s", "idle_timeout")) wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout)
stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout)
stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod)
signal := E.Collect(errs, validateSignal, cont.StopSignal)
wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout) if errs.HasError() {
b.Add(err.Subjectf("%s", "wake_timeout")) return nil, errs.Error()
stopTimeout, err := validateDurationPostitive(cont.StopTimeout)
b.Add(err.Subjectf("%s", "stop_timeout"))
stopMethod, err := validateStopMethod(cont.StopMethod)
b.Add(err)
signal, err := validateSignal(cont.StopSignal)
b.Add(err)
if err := b.Build(); err != nil {
return
} }
return &Config{ return &Config{
@ -80,33 +71,33 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) {
}, nil }, nil
} }
func validateDurationPostitive(value string) (time.Duration, E.Error) { func validateDurationPostitive(value string) (time.Duration, error) {
d, err := time.ParseDuration(value) d, err := time.ParseDuration(value)
if err != nil { if err != nil {
return 0, E.Invalid("duration", value).With(err) return 0, err
} }
if d < 0 { if d < 0 {
return 0, E.Invalid("duration", "negative value") return 0, errors.New("duration must be positive")
} }
return d, nil return d, nil
} }
func validateSignal(s string) (Signal, E.Error) { func validateSignal(s string) (Signal, error) {
switch s { switch s {
case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT",
"INT", "TERM", "HUP", "QUIT": "INT", "TERM", "HUP", "QUIT":
return Signal(s), nil return Signal(s), nil
} }
return "", E.Invalid("signal", s) return "", errors.New("invalid signal " + s)
} }
func validateStopMethod(s string) (StopMethod, E.Error) { func validateStopMethod(s string) (StopMethod, error) {
sm := StopMethod(s) sm := StopMethod(s)
switch sm { switch sm {
case StopMethodPause, StopMethodStop, StopMethodKill: case StopMethodPause, StopMethodStop, StopMethodKill:
return sm, nil return sm, nil
default: default:
return "", E.Invalid("stop_method", sm) return "", errors.New("invalid stop method " + s)
} }
} }

View file

@ -42,7 +42,7 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr
watcher, err := registerWatcher(providerSubTask, entry, waker) watcher, err := registerWatcher(providerSubTask, entry, waker)
if err != nil { if err != nil {
return nil, err return nil, E.Errorf("register watcher: %w", err)
} }
if rp != nil { if rp != nil {
@ -75,6 +75,9 @@ func (w *Watcher) Start(routeSubTask task.Task) E.Error {
// Finish implements health.HealthMonitor. // Finish implements health.HealthMonitor.
func (w *Watcher) Finish(reason any) { func (w *Watcher) Finish(reason any) {
if w.stream != nil {
w.stream.Close()
}
} }
// Name implements health.HealthMonitor. // Name implements health.HealthMonitor.

View file

@ -7,9 +7,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -20,21 +18,26 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if !shouldNext { if !shouldNext {
return return
} }
w.rp.ServeHTTP(rw, r) select {
case <-r.Context().Done():
return
default:
w.rp.ServeHTTP(rw, r)
}
} }
func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) {
w.resetIdleTimer() w.resetIdleTimer()
if r.Body != nil {
defer r.Body.Close()
}
// pass through if container is already ready // pass through if container is already ready
if w.ready.Load() { if w.ready.Load() {
return true return true
} }
if r.Body != nil {
defer r.Body.Close()
}
accept := gphttp.GetAccept(r.Header) accept := gphttp.GetAccept(r.Header)
acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty())
@ -49,23 +52,22 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
rw.Header().Add("Cache-Control", "must-revalidate") rw.Header().Add("Cache-Control", "must-revalidate")
rw.Header().Add("Connection", "close") rw.Header().Add("Connection", "close")
if _, err := rw.Write(body); err != nil { if _, err := rw.Write(body); err != nil {
w.l.Errorf("error writing http response: %s", err) w.Err(err).Msg("error writing http response")
} }
return return false
} }
ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout"))
defer cancel() defer cancel()
checkCanceled := func() bool { checkCanceled := func() (canceled bool) {
select { select {
case <-w.task.Context().Done():
w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context()))
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true
case <-ctx.Done(): case <-ctx.Done():
w.l.Debugf("wake canceled: %s", context.Cause(ctx)) w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled")
http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) return true
case <-w.task.Context().Done():
w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled")
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
return true return true
default: default:
return false return false
@ -76,12 +78,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
return false return false
} }
w.l.Debug("wake signal received") w.WakeTrace().Msg("signal received")
err := w.wakeIfStopped() err := w.wakeIfStopped()
if err != nil { if err != nil {
w.l.Error(E.FailWith("wake", err)) w.WakeError(err).Send()
http.Error(rw, "Error waking container", http.StatusInternalServerError) http.Error(rw, "Error waking container", http.StatusInternalServerError)
return return false
} }
for { for {
@ -92,11 +94,11 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
if w.Status() == health.StatusHealthy { if w.Status() == health.StatusHealthy {
w.resetIdleTimer() w.resetIdleTimer()
if isCheckRedirect { if isCheckRedirect {
logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL()) w.Debug().Msgf("redirecting to %s ...", w.hc.URL())
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
return return false
} }
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) w.Debug().Msgf("passing through to %s ...", w.hc.URL())
return true return true
} }

View file

@ -7,7 +7,7 @@ import (
"net" "net"
"time" "time"
"github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -16,25 +16,25 @@ func (w *Watcher) Addr() net.Addr {
return w.stream.Addr() return w.stream.Addr()
} }
// Setup implements types.Stream.
func (w *Watcher) Setup() error { func (w *Watcher) Setup() error {
return w.stream.Setup() return w.stream.Setup()
} }
// Accept implements types.Stream. // Accept implements types.Stream.
func (w *Watcher) Accept() (conn net.Conn, err error) { func (w *Watcher) Accept() (conn types.StreamConn, err error) {
conn, err = w.stream.Accept() conn, err = w.stream.Accept()
if err != nil { if err != nil {
logrus.Errorf("accept failed with error: %s", err)
return return
} }
if err := w.wakeFromStream(); err != nil { if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.l.Error(err) w.WakeError(wakeErr).Msg("error waking from stream")
} }
return return
} }
// Handle implements types.Stream. // Handle implements types.Stream.
func (w *Watcher) Handle(conn net.Conn) error { func (w *Watcher) Handle(conn types.StreamConn) error {
if err := w.wakeFromStream(); err != nil { if err := w.wakeFromStream(); err != nil {
return err return err
} }
@ -54,11 +54,11 @@ func (w *Watcher) wakeFromStream() error {
return nil return nil
} }
w.l.Debug("wake signal received") w.WakeDebug().Msg("wake signal received")
wakeErr := w.wakeIfStopped() wakeErr := w.wakeIfStopped()
if wakeErr != nil { if wakeErr != nil {
wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr) wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr)
w.l.Error(wakeErr) w.WakeError(wakeErr).Msg("wake failed")
return wakeErr return wakeErr
} }
@ -69,18 +69,18 @@ func (w *Watcher) wakeFromStream() error {
select { select {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
cause := w.task.FinishCause() cause := w.task.FinishCause()
w.l.Debugf("wake canceled: %s", cause) w.WakeDebug().Str("cause", cause.Error()).Msg("canceled")
return cause return cause
case <-ctx.Done(): case <-ctx.Done():
cause := context.Cause(ctx) cause := context.Cause(ctx)
w.l.Debugf("wake canceled: %s", cause) w.WakeDebug().Str("cause", cause.Error()).Msg("timeout")
return cause return cause
default: default:
} }
if w.Status() == health.StatusHealthy { if w.Status() == health.StatusHealthy {
w.resetIdleTimer() w.resetIdleTimer()
logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String())
return nil return nil
} }

View file

@ -3,15 +3,15 @@ package idlewatcher
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/proxy/entry"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
@ -25,6 +25,8 @@ type (
Watcher struct { Watcher struct {
_ U.NoCopy _ U.NoCopy
zerolog.Logger
*idlewatcher.Config *idlewatcher.Config
*waker *waker
@ -32,7 +34,6 @@ type (
stopByMethod StopCallback // send a docker command w.r.t. `stop_method` stopByMethod StopCallback // send a docker command w.r.t. `stop_method`
ticker *time.Ticker ticker *time.Ticker
task task.Task task task.Task
l *logrus.Entry
} }
WakeDone <-chan error WakeDone <-chan error
@ -44,13 +45,12 @@ var (
watcherMap = F.NewMapOf[string, *Watcher]() watcherMap = F.NewMapOf[string, *Watcher]()
watcherMapMu sync.Mutex watcherMapMu sync.Mutex
logger = logrus.WithField("module", "idle_watcher") logger = logging.With().Str("module", "idle_watcher").Logger()
) )
const dockerReqTimeout = 3 * time.Second const dockerReqTimeout = 3 * time.Second
func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) { func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) {
failure := E.Failure("idle_watcher register")
cfg := entry.IdlewatcherConfig() cfg := entry.IdlewatcherConfig()
if cfg.IdleTimeout == 0 { if cfg.IdleTimeout == 0 {
@ -71,17 +71,17 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
} }
client, err := D.ConnectClient(cfg.DockerHost) client, err := D.ConnectClient(cfg.DockerHost)
if err.HasError() { if err != nil {
return nil, failure.With(err) return nil, err
} }
w := &Watcher{ w := &Watcher{
Logger: logger.With().Str("name", cfg.ContainerName).Logger(),
Config: cfg, Config: cfg,
waker: waker, waker: waker,
client: client, client: client,
task: providerSubtask, task: providerSubtask,
ticker: time.NewTicker(cfg.IdleTimeout), ticker: time.NewTicker(cfg.IdleTimeout),
l: logger.WithField("container", cfg.ContainerName),
} }
w.stopByMethod = w.getStopCallback() w.stopByMethod = w.getStopCallback()
watcherMap.Store(key, w) watcherMap.Store(key, w)
@ -99,6 +99,23 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker)
return w, nil return w, nil
} }
// WakeDebug logs a debug message related to waking the container.
func (w *Watcher) WakeDebug() *zerolog.Event {
return w.Debug().Str("action", "wake")
}
func (w *Watcher) WakeTrace() *zerolog.Event {
return w.Trace().Str("action", "wake")
}
func (w *Watcher) WakeError(err error) *zerolog.Event {
return w.Err(err).Str("action", "wake")
}
func (w *Watcher) LogReason(action, reason string) {
w.Info().Str("reason", reason).Msg(action)
}
func (w *Watcher) containerStop(ctx context.Context) error { func (w *Watcher) containerStop(ctx context.Context) error {
return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{ return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{
Signal: string(w.StopSignal), Signal: string(w.StopSignal),
@ -130,7 +147,7 @@ func (w *Watcher) containerStatus() (string, error) {
defer cancel() defer cancel()
json, err := w.client.ContainerInspect(ctx, w.ContainerID) json, err := w.client.ContainerInspect(ctx, w.ContainerID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to inspect container: %w", err) return "", err
} }
return json.State.Status, nil return json.State.Status, nil
} }
@ -181,7 +198,7 @@ func (w *Watcher) getStopCallback() StopCallback {
} }
func (w *Watcher) resetIdleTimer() { func (w *Watcher) resetIdleTimer() {
w.l.Trace("reset idle timer") w.Trace().Msg("reset idle timer")
w.ticker.Reset(w.IdleTimeout) w.ticker.Reset(w.IdleTimeout)
} }
@ -190,7 +207,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{
Filters: W.NewDockerFilter( Filters: W.NewDockerFilter(
W.DockerFilterContainer, W.DockerFilterContainer,
W.DockerrFilterContainer(w.ContainerID), W.DockerFilterContainerNameID(w.ContainerID),
W.DockerFilterStart, W.DockerFilterStart,
W.DockerFilterStop, W.DockerFilterStop,
W.DockerFilterDie, W.DockerFilterDie,
@ -214,7 +231,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas
// //
// it exits only if the context is canceled, the container is destroyed, // it exits only if the context is canceled, the container is destroyed,
// errors occured on docker client, or route provider died (mainly caused by config reload). // errors occured on docker client, or route provider died (mainly caused by config reload).
func (w *Watcher) watchUntilDestroy() error { func (w *Watcher) watchUntilDestroy() (returnCause error) {
dockerWatcher := W.NewDockerWatcherWithClient(w.client) dockerWatcher := W.NewDockerWatcherWithClient(w.client)
eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher)
defer eventTask.Finish("stopped") defer eventTask.Finish("stopped")
@ -224,36 +241,36 @@ func (w *Watcher) watchUntilDestroy() error {
case <-w.task.Context().Done(): case <-w.task.Context().Done():
return w.task.FinishCause() return w.task.FinishCause()
case err := <-dockerEventErrCh: case err := <-dockerEventErrCh:
if err != nil && err.IsNot(context.Canceled) { if !err.Is(context.Canceled) {
w.l.Error(E.FailWith("docker watcher", err)) E.LogError("idlewatcher error", err, &w.Logger)
return err.Error()
} }
return err
case e := <-dockerEventCh: case e := <-dockerEventCh:
switch { switch {
case e.Action == events.ActionContainerDestroy: case e.Action == events.ActionContainerDestroy:
w.ContainerRunning = false w.ContainerRunning = false
w.ready.Store(false) w.ready.Store(false)
w.l.Info("watcher stopped by container destruction") w.LogReason("watcher stopped", "container destroyed")
return errors.New("container destroyed") return errors.New("container destroyed")
// create / start / unpause // create / start / unpause
case e.Action.IsContainerWake(): case e.Action.IsContainerWake():
w.ContainerRunning = true w.ContainerRunning = true
w.resetIdleTimer() w.resetIdleTimer()
w.l.Info("container awaken") w.Info().Msg("awaken")
case e.Action.IsContainerSleep(): // stop / pause / kil case e.Action.IsContainerSleep(): // stop / pause / kil
w.ContainerRunning = false w.ContainerRunning = false
w.ready.Store(false) w.ready.Store(false)
w.ticker.Stop() w.ticker.Stop()
default: default:
w.l.Errorf("unexpected docker event: %s", e) w.Error().Msg("unexpected docker event: " + e.String())
} }
// container name changed should also change the container id // container name changed should also change the container id
if w.ContainerName != e.ActorName { if w.ContainerName != e.ActorName {
w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName) w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName)
w.ContainerName = e.ActorName w.ContainerName = e.ActorName
} }
if w.ContainerID != e.ActorID { if w.ContainerID != e.ActorID {
w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID) w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID)
w.ContainerID = e.ActorID w.ContainerID = e.ActorID
// recreate event stream // recreate event stream
eventTask.Finish("recreate event stream") eventTask.Finish("recreate event stream")
@ -263,9 +280,9 @@ func (w *Watcher) watchUntilDestroy() error {
w.ticker.Stop() w.ticker.Stop()
if w.ContainerRunning { if w.ContainerRunning {
if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) { if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) {
w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err) w.Err(err).Msgf("container stop with method %q failed", w.StopMethod)
} else { } else {
w.l.Info("container stopped by idle timeout") w.LogReason("container stopped", "idle timeout")
} }
} }
} }

View file

@ -4,28 +4,26 @@ import (
"context" "context"
"errors" "errors"
"time" "time"
E "github.com/yusing/go-proxy/internal/error"
) )
func Inspect(dockerHost string, containerID string) (*Container, E.Error) { func Inspect(dockerHost string, containerID string) (*Container, error) {
client, err := ConnectClient(dockerHost) client, err := ConnectClient(dockerHost)
defer client.Close() defer client.Close()
if err.HasError() { if err != nil {
return nil, E.FailWith("connect to docker", err) return nil, err
} }
return client.Inspect(containerID) return client.Inspect(containerID)
} }
func (c Client) Inspect(containerID string) (*Container, E.Error) { func (c Client) Inspect(containerID string) (*Container, error) {
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout"))
defer cancel() defer cancel()
json, err := c.ContainerInspect(ctx, containerID) json, err := c.ContainerInspect(ctx, containerID)
if err != nil { if err != nil {
return nil, E.From(err) return nil, err
} }
return FromJSON(json, c.key), nil return FromJSON(json, c.key), nil
} }

View file

@ -24,6 +24,11 @@ type (
NestedLabelMap map[string]U.SerializedObject NestedLabelMap map[string]U.SerializedObject
) )
var (
ErrApplyToNil = E.New("label value is nil")
ErrFieldNotExist = E.New("field does not exist")
)
func (l *Label) String() string { func (l *Label) String() string {
if l.Attribute == "" { if l.Attribute == "" {
return l.Namespace + "." + l.Target return l.Namespace + "." + l.Target
@ -41,7 +46,7 @@ func (l *Label) String() string {
// - error: an error if the field does not exist. // - error: an error if the field does not exist.
func ApplyLabel[T any](obj *T, l *Label) E.Error { func ApplyLabel[T any](obj *T, l *Label) E.Error {
if obj == nil { if obj == nil {
return E.Invalid("nil object", l) return ErrApplyToNil.Subject(l.String())
} }
switch nestedLabel := l.Value.(type) { switch nestedLabel := l.Value.(type) {
case *Label: case *Label:
@ -54,7 +59,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
} }
} }
if !field.IsValid() { if !field.IsValid() {
return E.NotExist("field", l.Attribute) return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String())
} }
dst, ok := field.Interface().(NestedLabelMap) dst, ok := field.Interface().(NestedLabelMap)
if !ok { if !ok {
@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
} else { } else {
field = field.Addr() field = field.Addr()
} }
return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface()) err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface())
if err != nil {
return err.Subject(l.String())
}
return nil
} }
if dst == nil { if dst == nil {
field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]())) field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]()))
@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error {
dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value
return nil return nil
default: default:
return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj)
if err != nil {
return err.Subject(l.String())
}
return nil
} }
} }
func ParseLabel(label string, value string) (*Label, E.Error) { func ParseLabel(label string, value string) *Label {
parts := strings.Split(label, ".") parts := strings.Split(label, ".")
if len(parts) < 2 { if len(parts) < 2 {
return &Label{ return &Label{
Namespace: label, Namespace: label,
Value: value, Value: value,
}, nil }
} }
l := &Label{ l := &Label{
@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.Error) {
l.Attribute = parts[2] l.Attribute = parts[2]
default: default:
l.Attribute = parts[2] l.Attribute = parts[2]
nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value) nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value)
if err != nil {
return nil, err
}
l.Value = nestedLabel l.Value = nestedLabel
} }
return l, nil return l
} }

View file

@ -20,9 +20,8 @@ func makeLabel(ns, name, attr string) string {
func TestNestedLabel(t *testing.T) { func TestNestedLabel(t *testing.T) {
mAttr := "prop1" mAttr := "prop1"
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error()) sGot := ExpectType[*Label](t, lbl.Value)
sGot := ExpectType[*Label](t, pl.Value)
ExpectFalse(t, sGot == nil) ExpectFalse(t, sGot == nil)
ExpectEqual(t, sGot.Namespace, mName) ExpectEqual(t, sGot.Namespace, mName)
ExpectEqual(t, sGot.Attribute, mAttr) ExpectEqual(t, sGot.Attribute, mAttr)
@ -32,10 +31,9 @@ func TestApplyNestedLabel(t *testing.T) {
entry := new(struct { entry := new(struct {
Middlewares NestedLabelMap `yaml:"middlewares"` Middlewares NestedLabelMap `yaml:"middlewares"`
}) })
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error()) err := ApplyLabel(entry, lbl)
err = ApplyLabel(entry, pl) ExpectNoError(t, err)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName] middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok) ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr]) got := ExpectType[string](t, middleware1[mAttr])
@ -52,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) {
entry.Middlewares[mName] = make(U.SerializedObject) entry.Middlewares[mName] = make(U.SerializedObject)
entry.Middlewares[mName][checkAttr] = checkV entry.Middlewares[mName][checkAttr] = checkV
pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v)
ExpectNoError(t, err.Error()) err := ApplyLabel(entry, lbl)
err = ApplyLabel(entry, pl) ExpectNoError(t, err)
ExpectNoError(t, err.Error())
middleware1, ok := entry.Middlewares[mName] middleware1, ok := entry.Middlewares[mName]
ExpectTrue(t, ok) ExpectTrue(t, ok)
got := ExpectType[string](t, middleware1[mAttr]) got := ExpectType[string](t, middleware1[mAttr])
@ -74,10 +71,9 @@ func TestApplyNestedLabelNoAttr(t *testing.T) {
entry.Middlewares = make(NestedLabelMap) entry.Middlewares = make(NestedLabelMap)
entry.Middlewares[mName] = make(U.SerializedObject) entry.Middlewares[mName] = make(U.SerializedObject)
pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v) lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v)
ExpectNoError(t, err.Error()) err := ApplyLabel(entry, lbl)
err = ApplyLabel(entry, pl) ExpectNoError(t, err)
ExpectNoError(t, err.Error())
_, ok := entry.Middlewares[mName] _, ok := entry.Middlewares[mName]
ExpectTrue(t, ok) ExpectTrue(t, ok)
} }

View file

@ -8,7 +8,6 @@ import (
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
E "github.com/yusing/go-proxy/internal/error"
) )
var listOptions = container.ListOptions{ var listOptions = container.ListOptions{
@ -23,19 +22,19 @@ var listOptions = container.ListOptions{
All: true, All: true,
} }
func ListContainers(clientHost string) ([]types.Container, E.Error) { func ListContainers(clientHost string) ([]types.Container, error) {
dockerClient, err := ConnectClient(clientHost) dockerClient, err := ConnectClient(clientHost)
if err.HasError() { if err != nil {
return nil, E.FailWith("connect to docker", err) return nil, err
} }
defer dockerClient.Close() defer dockerClient.Close()
ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout")) ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout"))
defer cancel() defer cancel()
containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions)) containers, err := dockerClient.ContainerList(ctx, listOptions)
if err.HasError() { if err != nil {
return nil, E.FailWith("list containers", err) return nil, err
} }
return containers, nil return containers, nil
} }

View file

@ -1,5 +1,7 @@
package docker package docker
import "github.com/sirupsen/logrus" import (
"github.com/yusing/go-proxy/internal/logging"
)
var logger = logrus.WithField("module", "docker") var logger = logging.With().Str("module", "docker").Logger()

46
internal/error/base.go Normal file
View file

@ -0,0 +1,46 @@
package error
import (
"errors"
"fmt"
)
// baseError is an immutable wrapper around an error.
type baseError struct {
Err error `json:"err"`
}
func (err *baseError) Unwrap() error {
return err.Err
}
func (err *baseError) Is(other error) bool {
if other, ok := other.(*baseError); ok {
return errors.Is(err.Err, other.Err)
}
return errors.Is(err.Err, other)
}
func (err baseError) Subject(subject string) Error {
err.Err = PrependSubject(subject, err.Err)
return &err
}
func (err *baseError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err baseError) With(extra error) Error {
return &nestedError{&err, []error{extra}}
}
func (err baseError) Withf(format string, args ...any) Error {
return &nestedError{&err, []error{fmt.Errorf(format, args...)}}
}
func (err *baseError) Error() string {
return err.Err.Error()
}

View file

@ -1,104 +1,104 @@
package error package error
import ( import (
"errors"
"fmt" "fmt"
"sync" "sync"
) )
type Builder struct { type Builder struct {
*builder about string
} errs []error
type builder struct {
message string
errors []Error
sync.Mutex sync.Mutex
} }
func NewBuilder(format string, args ...any) Builder { func NewBuilder(about string) *Builder {
if len(args) > 0 { return &Builder{about: about}
return Builder{&builder{message: fmt.Sprintf(format, args...)}} }
func (b *Builder) About() string {
if !b.HasError() {
return ""
} }
return Builder{&builder{message: format}} return b.about
}
//go:inline
func (b *Builder) HasError() bool {
return len(b.errs) > 0
}
func (b *Builder) Error() Error {
if !b.HasError() {
return nil
}
if len(b.errs) == 1 {
return From(b.errs[0])
}
return &nestedError{Err: New(b.about), Extras: b.errs}
}
func (b *Builder) String() string {
if !b.HasError() {
return ""
}
return (&nestedError{Err: New(b.about), Extras: b.errs}).Error()
} }
// Add adds an error to the Builder. // Add adds an error to the Builder.
// //
// adding nil is no-op, // adding nil is no-op,
// func (b *Builder) Add(err error) *Builder {
// flatten is a boolean flag to flatten the NestedError. if err == nil {
func (b Builder) Add(err Error, flatten ...bool) { return b
if err != nil { }
b.Lock()
if len(flatten) > 0 && flatten[0] { b.Lock()
for _, e := range err.extras { defer b.Unlock()
b.errors = append(b.errors, &e)
} switch err := err.(type) {
case *baseError:
b.errs = append(b.errs, err.Err)
case *nestedError:
if err.Err == nil {
b.errs = append(b.errs, err.Extras...)
} else { } else {
b.errors = append(b.errors, err) b.errs = append(b.errs, err)
} }
b.Unlock()
}
}
func (b Builder) AddE(err error) {
b.Add(From(err))
}
func (b Builder) Addf(format string, args ...any) {
if len(args) > 0 {
b.Add(errorf(format, args...))
} else {
b.AddE(errors.New(format))
}
}
func (b Builder) AddRange(errs ...Error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.errors = append(b.errors, err)
}
}
func (b Builder) AddRangeE(errs ...error) {
b.Lock()
defer b.Unlock()
for _, err := range errs {
b.errors = append(b.errors, From(err))
}
}
// Build builds a NestedError based on the errors collected in the Builder.
//
// If there are no errors in the Builder, it returns a Nil() NestedError.
// Otherwise, it returns a NestedError with the message and the errors collected.
//
// Returns:
// - NestedError: the built NestedError.
func (b Builder) Build() Error {
if len(b.errors) == 0 {
return nil
}
return Join(b.message, b.errors...)
}
func (b Builder) To(ptr *Error) {
switch {
case ptr == nil:
return
case *ptr == nil:
*ptr = b.Build()
default: default:
(*ptr).extras = append((*ptr).extras, *b.Build()) b.errs = append(b.errs, err)
} }
return b
} }
func (b Builder) String() string { func (b *Builder) Adds(err string) *Builder {
return b.Build().String() b.Lock()
defer b.Unlock()
b.errs = append(b.errs, newError(err))
return b
} }
func (b Builder) HasError() bool { func (b *Builder) Addf(format string, args ...any) *Builder {
return len(b.errors) > 0 if len(args) > 0 {
b.Lock()
defer b.Unlock()
b.errs = append(b.errs, fmt.Errorf(format, args...))
} else {
b.Adds(format)
}
return b
}
func (b *Builder) AddRange(errs ...error) *Builder {
b.Lock()
defer b.Unlock()
for _, err := range errs {
if err != nil {
b.errs = append(b.errs, err)
}
}
return b
} }

View file

@ -1,6 +1,9 @@
package error_test package error_test
import ( import (
"context"
"errors"
"io"
"testing" "testing"
. "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/error"
@ -8,14 +11,13 @@ import (
) )
func TestBuilderEmpty(t *testing.T) { func TestBuilderEmpty(t *testing.T) {
eb := NewBuilder("qwer") eb := NewBuilder("foo")
ExpectTrue(t, eb.Build() == nil) ExpectTrue(t, errors.Is(eb.Error(), nil))
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError()) ExpectFalse(t, eb.HasError())
} }
func TestBuilderAddNil(t *testing.T) { func TestBuilderAddNil(t *testing.T) {
eb := NewBuilder("asdf") eb := NewBuilder("foo")
var err Error var err Error
for range 3 { for range 3 {
eb.Add(nil) eb.Add(nil)
@ -23,41 +25,31 @@ func TestBuilderAddNil(t *testing.T) {
for range 3 { for range 3 {
eb.Add(err) eb.Add(err)
} }
ExpectTrue(t, eb.Build() == nil) eb.AddRange(nil, nil, err)
ExpectTrue(t, eb.Build().NoError())
ExpectFalse(t, eb.HasError()) ExpectFalse(t, eb.HasError())
ExpectTrue(t, eb.Error() == nil)
}
func TestBuilderIs(t *testing.T) {
eb := NewBuilder("foo")
eb.Add(context.Canceled)
eb.Add(io.ErrShortBuffer)
ExpectTrue(t, eb.HasError())
ExpectError(t, io.ErrShortBuffer, eb.Error())
ExpectError(t, context.Canceled, eb.Error())
} }
func TestBuilderNested(t *testing.T) { func TestBuilderNested(t *testing.T) {
eb := NewBuilder("error occurred") eb := NewBuilder("action failed")
eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2"))) eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2"))
eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) eb.Add(New("Action 2").Withf("Inner: 3"))
got := eb.Build().String()
expected1 := (`error occurred:
- Action 1 failed:
- invalid Inner: 1
- invalid Inner: 2
- Action 2 failed:
- invalid Inner: 3`)
expected2 := (`error occurred:
- Action 1 failed:
- invalid Inner: "1"
- invalid Inner: "2"
- Action 2 failed:
- invalid Inner: "3"`)
ExpectEqualAny(t, got, []string{expected1, expected2})
}
func TestBuilderTo(t *testing.T) {
eb := NewBuilder("error occurred")
eb.Addf("abcd")
var err Error
eb.To(&err)
got := err.String()
expected := (`error occurred:
- abcd`)
got := eb.String()
expected := `action failed
Action 1
Inner: 1
Inner: 2
Action 2
Inner: 3`
ExpectEqual(t, got, expected) ExpectEqual(t, got, expected)
} }

View file

@ -1,317 +1,31 @@
package error package error
import ( type Error interface {
"encoding/json" error
"errors"
"fmt"
"strings"
)
type ( // Is is a wrapper for errors.Is when there is no sub-error.
Error = *ErrorImpl //
ErrorImpl struct { // When there are sub-errors, they will also be checked.
subject string Is(other error) bool
err error // With appends a sub-error to the error.
extras []ErrorImpl With(extra error) Error
} // Withf is a wrapper for With(fmt.Errorf(format, args...)).
ErrorJSONMarshaller struct { Withf(format string, args ...any) Error
Subject string `json:"subject"` // Subject prepends the given subject with a colon and space to the error message.
Err string `json:"error"` //
Extras []ErrorJSONMarshaller `json:"extras,omitempty"` // If there is already a subject in the error message, the subject will be
} // prepended to the existing subject with " > ".
) //
// Subject empty string is ignored.
func From(err error) Error { Subject(subject string) Error
if IsNil(err) { // Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)).
return nil Subjectf(format string, args ...any) Error
}
return &ErrorImpl{err: err}
} }
func FromJSON(data []byte) (Error, bool) { // this makes JSON marshalling work,
var j ErrorJSONMarshaller // as the builtin one doesn't.
if err := json.Unmarshal(data, &j); err != nil { type errStr string
return nil, false
} func (err errStr) Error() string {
if j.Err == "" { return string(err)
return nil, false
}
extras := make([]ErrorImpl, len(j.Extras))
for i, e := range j.Extras {
extra, ok := fromJSONObject(e)
if !ok {
return nil, false
}
extras[i] = *extra
}
return &ErrorImpl{
subject: j.Subject,
err: errors.New(j.Err),
extras: extras,
}, true
}
func TryUnwrap(err error) error {
if err == nil {
return nil
}
if unwrapped := errors.Unwrap(err); unwrapped != nil {
return unwrapped
}
return err
}
// Check is a helper function that
// convert (T, error) to (T, NestedError).
func Check[T any](obj T, err error) (T, Error) {
return obj, From(err)
}
func Join(message string, err ...Error) Error {
extras := make([]ErrorImpl, len(err))
nErr := 0
for i, e := range err {
if e == nil {
continue
}
extras[i] = *e
nErr++
}
if nErr == 0 {
return nil
}
return &ErrorImpl{
err: errors.New(message),
extras: extras,
}
}
func JoinE(message string, err ...error) Error {
b := NewBuilder("%s", message)
for _, e := range err {
b.AddE(e)
}
return b.Build()
}
func IsNil(err error) bool {
return err == nil
}
func IsNotNil(err error) bool {
return err != nil
}
func (ne Error) String() string {
var buf strings.Builder
ne.writeToSB(&buf, 0, "")
return buf.String()
}
func (ne Error) Is(err error) bool {
if ne == nil {
return err == nil
}
// return errors.Is(ne.err, err)
if errors.Is(ne.err, err) {
return true
}
for _, e := range ne.extras {
if e.Is(err) {
return true
}
}
return false
}
func (ne Error) IsNot(err error) bool {
return !ne.Is(err)
}
func (ne Error) Error() error {
if ne == nil {
return nil
}
return ne.buildError(0, "")
}
func (ne Error) With(s any) Error {
if ne == nil {
return ne
}
var msg string
switch ss := s.(type) {
case nil:
return ne
case *ErrorImpl:
if len(ss.extras) == 1 {
ne.extras = append(ne.extras, ss.extras[0])
return ne
}
return ne.withError(ss)
case error:
// unwrap only once
return ne.withError(From(TryUnwrap(ss)))
case string:
msg = ss
case fmt.Stringer:
return ne.appendMsg(ss.String())
default:
return ne.appendMsg(fmt.Sprint(s))
}
return ne.withError(From(errors.New(msg)))
}
func (ne Error) Extraf(format string, args ...any) Error {
return ne.With(errorf(format, args...))
}
func (ne Error) Subject(s any, sep ...string) Error {
if ne == nil {
return ne
}
var subject string
switch ss := s.(type) {
case string:
subject = ss
case fmt.Stringer:
subject = ss.String()
default:
subject = fmt.Sprint(s)
}
switch {
case ne.subject == "":
ne.subject = subject
case len(sep) > 0:
ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject)
default:
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
return ne
}
func (ne Error) Subjectf(format string, args ...any) Error {
if ne == nil {
return ne
}
return ne.Subject(fmt.Sprintf(format, args...))
}
func (ne Error) JSONObject() ErrorJSONMarshaller {
extras := make([]ErrorJSONMarshaller, len(ne.extras))
for i, e := range ne.extras {
extras[i] = e.JSONObject()
}
return ErrorJSONMarshaller{
Subject: ne.subject,
Err: ne.err.Error(),
Extras: extras,
}
}
func (ne Error) JSON() []byte {
b, err := json.MarshalIndent(ne.JSONObject(), "", " ")
if err != nil {
panic(err)
}
return b
}
func (ne Error) NoError() bool {
return ne == nil
}
func (ne Error) HasError() bool {
return ne != nil
}
func errorf(format string, args ...any) Error {
for i, arg := range args {
if err, ok := arg.(error); ok {
if unwrapped := errors.Unwrap(err); unwrapped != nil {
args[i] = unwrapped
}
}
}
return From(fmt.Errorf(format, args...))
}
func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) {
data, err := json.Marshal(obj)
if err != nil {
return nil, false
}
return FromJSON(data)
}
func (ne Error) withError(err Error) Error {
if ne != nil && err != nil {
ne.extras = append(ne.extras, *err)
}
return ne
}
func (ne Error) appendMsg(msg string) Error {
if ne == nil {
return nil
}
ne.err = fmt.Errorf("%w %s", ne.err, msg)
return ne
}
func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) {
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return
}
if ne.subject != "" {
sb.WriteString(ne.subject)
sb.WriteRune(' ')
}
sb.WriteString(ne.err.Error())
if len(ne.extras) > 0 {
sb.WriteRune(':')
for _, extra := range ne.extras {
sb.WriteRune('\n')
extra.writeToSB(sb, level+1, "- ")
}
}
}
func (ne Error) buildError(level int, prefix string) error {
var res error
var sb strings.Builder
for range level {
sb.WriteString(" ")
}
sb.WriteString(prefix)
if ne.NoError() {
sb.WriteString("nil")
return errors.New(sb.String())
}
res = fmt.Errorf("%s%w", sb.String(), ne.err)
sb.Reset()
if ne.subject != "" {
sb.WriteString(fmt.Sprintf(" for %q", ne.subject))
}
if len(ne.extras) > 0 {
sb.WriteRune(':')
res = fmt.Errorf("%w%s", res, sb.String())
for _, extra := range ne.extras {
res = errors.Join(res, extra.buildError(level+1, "- "))
}
} else {
res = fmt.Errorf("%w%s", res, sb.String())
}
return res
} }

View file

@ -1,107 +1,157 @@
package error_test package error
import ( import (
"errors" "errors"
"strings"
"testing" "testing"
. "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestBaseString(t *testing.T) {
ExpectEqual(t, New("error").Error(), "error")
}
func TestBaseWithSubject(t *testing.T) {
err := New("error")
withSubject := err.Subject("foo")
withSubjectf := err.Subjectf("%s %s", "foo", "bar")
ExpectError(t, err, withSubject)
ExpectStrEqual(t, withSubject.Error(), "foo: error")
ExpectTrue(t, withSubject.Is(err))
ExpectError(t, err, withSubjectf)
ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error")
ExpectTrue(t, withSubjectf.Is(err))
}
func TestBaseWithExtra(t *testing.T) {
err := New("error")
extra := New("bar").Subject("baz")
withExtra := err.With(extra)
ExpectTrue(t, withExtra.Is(extra))
ExpectTrue(t, withExtra.Is(err))
ExpectTrue(t, errors.Is(withExtra, extra))
ExpectTrue(t, errors.Is(withExtra, err))
ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error()))
ExpectTrue(t, strings.Contains(withExtra.Error(), "baz"))
}
func TestBaseUnwrap(t *testing.T) {
err := errors.New("err")
wrapped := From(err)
ExpectError(t, err, errors.Unwrap(wrapped))
}
func TestNestedUnwrap(t *testing.T) {
err := errors.New("err")
err2 := New("err2")
wrapped := From(err).Subject("foo").With(err2.Subject("bar"))
unwrapper, ok := wrapped.(interface{ Unwrap() []error })
ExpectTrue(t, ok)
ExpectError(t, err, wrapped)
ExpectError(t, err2, wrapped)
ExpectEqual(t, len(unwrapper.Unwrap()), 2)
}
func TestErrorIs(t *testing.T) { func TestErrorIs(t *testing.T) {
ExpectTrue(t, Failure("foo").Is(ErrFailure)) from := errors.New("error")
ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure)) err := From(from)
ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid)) ExpectError(t, from, err)
ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid))
ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid)) ExpectTrue(t, err.Is(from))
ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure)) ExpectFalse(t, err.Is(New("error")))
ExpectFalse(t, Invalid("foo", "bar").Is(nil)) ExpectTrue(t, errors.Is(err.Subject("foo"), from))
ExpectTrue(t, errors.Is(err.Withf("foo"), from))
ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure)) ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid))
ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure))
ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists))
} }
func TestErrorNestedIs(t *testing.T) { func TestErrorImmutability(t *testing.T) {
var err Error err := New("err")
ExpectTrue(t, err.Is(nil)) err2 := New("err2")
err = Failure("some reason") for range 3 {
ExpectTrue(t, err.Is(ErrFailure)) // t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err)
ExpectFalse(t, err.Is(ErrDuplicated)) err.Subject("foo")
ExpectFalse(t, strings.Contains(err.Error(), "foo"))
err.With(Duplicated("something", "")) err.With(err2)
ExpectTrue(t, err.Is(ErrFailure)) ExpectFalse(t, strings.Contains(err.Error(), "extra"))
ExpectTrue(t, err.Is(ErrDuplicated)) ExpectFalse(t, err.Is(err2))
ExpectFalse(t, err.Is(ErrInvalid))
}
func TestIsNil(t *testing.T) { err = err.Subject("bar").Withf("baz")
var err Error ExpectTrue(t, err != nil)
ExpectTrue(t, err.Is(nil))
ExpectTrue(t, err == nil)
ExpectTrue(t, err.NoError())
eb := NewBuilder("")
returnNil := func() error {
return eb.Build().Error()
} }
ExpectTrue(t, IsNil(returnNil()))
ExpectTrue(t, returnNil() == nil)
ExpectTrue(t, (err.
Subject("any").
With("something").
Extraf("foo %s", "bar")) == nil)
}
func TestErrorSimple(t *testing.T) {
ne := Failure("foo bar")
ExpectEqual(t, ne.String(), "foo bar failed")
ne = ne.Subject("baz")
ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"")
} }
func TestErrorWith(t *testing.T) { func TestErrorWith(t *testing.T) {
ne := Failure("foo").With("bar").With("baz") err1 := New("err1")
ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz") err2 := New("err2")
err3 := err1.With(err2)
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
err2.Subject("foo")
ExpectTrue(t, err3.Is(err1))
ExpectTrue(t, err3.Is(err2))
// check if err3 is affected by err2.Subject
ExpectFalse(t, strings.Contains(err3.Error(), "foo"))
} }
func TestErrorNested(t *testing.T) { func TestErrorStringSimple(t *testing.T) {
inner := Failure("inner"). errFailure := New("generic failure")
With("1"). ne := errFailure.Subject("foo bar")
With("1") ExpectStrEqual(t, ne.Error(), "foo bar: generic failure")
inner2 := Failure("inner2"). ne = ne.Subject("baz")
ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure")
}
func TestErrorStringNested(t *testing.T) {
errFailure := New("generic failure")
inner := errFailure.Subject("inner").
Withf("1").
Withf("1")
inner2 := errFailure.Subject("inner2").
Subject("action 2"). Subject("action 2").
With("2"). Withf("2").
With("2") Withf("2")
inner3 := Failure("inner3"). inner3 := errFailure.Subject("inner3").
Subject("action 3"). Subject("action 3").
With("3"). Withf("3").
With("3") Withf("3")
ne := Failure("foo"). ne := errFailure.
With("bar"). Subject("foo").
With("baz"). Withf("bar").
Withf("baz").
With(inner). With(inner).
With(inner.With(inner2.With(inner3))) With(inner.With(inner2.With(inner3)))
want := `foo failed: want := `foo: generic failure
- bar bar
- baz baz
- inner failed: inner: generic failure
- 1 1
- 1 1
- inner failed: inner: generic failure
- 1 1
- 1 1
- inner2 failed for "action 2": action 2 > inner2: generic failure
- 2 2
- 2 2
- inner3 failed for "action 3": action 3 > inner3: generic failure
- 3 3
- 3` 3`
ExpectEqual(t, ne.String(), want) ExpectStrEqual(t, ne.Error(), want)
ExpectEqual(t, ne.Error().Error(), want)
} }

View file

@ -1,83 +0,0 @@
package error
import (
stderrors "errors"
"fmt"
"reflect"
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
ErrPanicRecv = stderrors.New("panic recovered from")
)
const fmtSubjectWhat = "%w %v: %q"
func Failure(what string) Error {
return errorf("%s %w", what, ErrFailure)
}
func FailedWhy(what string, why string) Error {
return Failure(what).With(why)
}
func FailWith(what string, err any) Error {
return Failure(what).With(err)
}
func Invalid(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrInvalid, subject, what)
}
func Unsupported(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnsupported, subject, what)
}
func Unexpected(subject, what any) Error {
return errorf(fmtSubjectWhat, ErrUnexpected, subject, what)
}
func UnexpectedError(err error) Error {
return errorf("%w error: %w", ErrUnexpected, err)
}
func NotExist(subject, what any) Error {
return errorf("%v %w: %v", subject, ErrNotExists, what)
}
func Missing(subject any) Error {
return errorf("%w %v", ErrMissing, subject)
}
func Duplicated(subject, what any) Error {
return errorf("%w %v: %v", ErrDuplicated, subject, what)
}
func OutOfRange(subject any, value any) Error {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}
func TypeError(subject any, from, to reflect.Type) Error {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) Error {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) Error {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}
func PanicRecv(format string, args ...any) Error {
return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...))
}

43
internal/error/log.go Normal file
View file

@ -0,0 +1,43 @@
package error
import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
)
func getLogger(logger ...*zerolog.Logger) *zerolog.Logger {
if len(logger) > 0 {
return logger[0]
}
return logging.GetLogger()
}
//go:inline
func LogFatal(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Fatal().Msg(err.Error())
}
//go:inline
func LogError(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Error().Msg(err.Error())
}
//go:inline
func LogWarn(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Warn().Msg(err.Error())
}
//go:inline
func LogPanic(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Panic().Msg(err.Error())
}
//go:inline
func LogInfo(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Info().Msg(err.Error())
}
//go:inline
func LogDebug(msg string, err error, logger ...*zerolog.Logger) {
getLogger(logger...).Debug().Msg(err.Error())
}

View file

@ -0,0 +1,120 @@
package error
import (
"errors"
"fmt"
"strings"
)
type nestedError struct {
Err error `json:"err"`
Extras []error `json:"extras"`
}
func (err nestedError) Subject(subject string) Error {
if err.Err == nil {
err.Err = newError(subject)
} else {
err.Err = PrependSubject(subject, err.Err)
}
return &err
}
func (err *nestedError) Subjectf(format string, args ...any) Error {
if len(args) > 0 {
return err.Subject(fmt.Sprintf(format, args...))
}
return err.Subject(format)
}
func (err nestedError) With(extra error) Error {
if extra != nil {
err.Extras = append(err.Extras, extra)
}
return &err
}
func (err nestedError) Withf(format string, args ...any) Error {
if len(args) > 0 {
err.Extras = append(err.Extras, fmt.Errorf(format, args...))
} else {
err.Extras = append(err.Extras, newError(format))
}
return &err
}
func (err *nestedError) Unwrap() []error {
if err.Err == nil {
if len(err.Extras) == 0 {
return nil
}
return err.Extras
}
return append([]error{err.Err}, err.Extras...)
}
func (err *nestedError) Is(other error) bool {
if errors.Is(err.Err, other) {
return true
}
for _, e := range err.Extras {
if errors.Is(e, other) {
return true
}
}
return false
}
func (err *nestedError) Error() string {
return buildError(err, 0)
}
//go:inline
func makeLine(err string, level int) string {
const bulletPrefix = "• "
const spaces = " "
if level == 0 {
return err
}
return spaces[:2*level] + bulletPrefix + err
}
func makeLines(errs []error, level int) []string {
if len(errs) == 0 {
return nil
}
lines := make([]string, 0, len(errs))
for _, err := range errs {
switch err := err.(type) {
case *nestedError:
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
default:
lines = append(lines, makeLine(err.Error(), level))
}
}
return lines
}
func buildError(err error, level int) string {
switch err := err.(type) {
case nil:
return makeLine("<nil>", level)
case *nestedError:
lines := make([]string, 0, 1+len(err.Extras))
if err.Err != nil {
lines = append(lines, makeLine(err.Err.Error(), level))
}
if extras := makeLines(err.Extras, level+1); len(extras) > 0 {
lines = append(lines, extras...)
}
return strings.Join(lines, "\n")
default:
return makeLine(err.Error(), level)
}
}

50
internal/error/subject.go Normal file
View file

@ -0,0 +1,50 @@
package error
import (
"strings"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
)
type withSubject struct {
Subject string `json:"subject"`
Err error `json:"err"`
}
const subjectSep = " > "
func highlight(subject string) string {
return ansi.HighlightRed + subject + ansi.Reset
}
func PrependSubject(subject string, err error) *withSubject {
switch err := err.(type) {
case *withSubject:
return err.Prepend(subject)
case *baseError:
return PrependSubject(subject, err.Err)
default:
return &withSubject{subject, err}
}
}
func (err withSubject) Prepend(subject string) *withSubject {
if subject != "" {
err.Subject = subject + subjectSep + err.Subject
}
return &err
}
func (err *withSubject) Is(other error) bool {
return err.Err == other
}
func (err *withSubject) Unwrap() error {
return err.Err
}
func (err *withSubject) Error() string {
subjects := strings.Split(err.Subject, subjectSep)
subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1])
return strings.Join(subjects, subjectSep) + ": " + err.Err.Error()
}

71
internal/error/utils.go Normal file
View file

@ -0,0 +1,71 @@
package error
import (
"errors"
"fmt"
)
var ErrInvalidErrorJson = errors.New("invalid error json")
func newError(message string) error {
return errStr(message)
}
func New(message string) Error {
if message == "" {
return nil
}
return &baseError{newError(message)}
}
func Errorf(format string, args ...any) Error {
return &baseError{fmt.Errorf(format, args...)}
}
func From(err error) Error {
if err == nil {
return nil
}
if err, ok := err.(Error); ok {
return err
}
return &baseError{err}
}
func Must[T any](v T, err error) T {
if err != nil {
LogPanic("must failed", err)
}
return v
}
func Join(errors ...error) Error {
n := 0
for _, err := range errors {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
errs := make([]error, 0, n)
for _, err := range errors {
if err != nil {
errs = append(errs, err)
}
}
return &nestedError{Extras: errs}
}
func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T {
result, err := fn(arg)
eb.Add(err)
return result
}
func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T {
result, err := fn(arg1, arg2)
eb.Add(err)
return result
}

View file

@ -53,7 +53,7 @@ func ListAvailableIcons() ([]string, error) {
icons = append(icons, content.Path) icons = append(icons, content.Path)
} }
} }
err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error() err = utils.SaveJSON(iconsCachePath, &icons, 0o644)
if err != nil { if err != nil {
log.Print("error saving cache", err) log.Print("error saving cache", err)
} }

View file

@ -0,0 +1,69 @@
package logging
import (
"os"
"strings"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
)
var logger zerolog.Logger
func init() {
var timeFmt string
var level zerolog.Level
var exclude []string
if common.IsTrace {
timeFmt = "04:05"
level = zerolog.TraceLevel
} else if common.IsDebug {
timeFmt = "01-02 15:04"
level = zerolog.DebugLevel
} else {
timeFmt = "01-02 15:04"
level = zerolog.InfoLevel
exclude = []string{"module"}
}
prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces
prefix := strings.Repeat(" ", prefixLength)
logger = zerolog.New(
zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: timeFmt,
FieldsExclude: exclude,
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
msg := msgI.(string)
lines := strings.Split(msg, "\n")
if len(lines) == 1 {
return msg
}
for i := 1; i < len(lines); i++ {
lines[i] = prefix + lines[i]
}
return strings.Join(lines, "\n")
},
},
).Level(level).With().Timestamp().Logger()
}
func DiscardLogger() { logger = zerolog.Nop() }
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
func GetLogger() *zerolog.Logger { return &logger }
func With() zerolog.Context { return logger.With() }
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
func Info() *zerolog.Event { return logger.Info() }
func Warn() *zerolog.Event { return logger.Warn() }
func Error() *zerolog.Event { return logger.Error() }
func Err(err error) *zerolog.Event { return logger.Err(err) }
func Debug() *zerolog.Event { return logger.Debug() }
func Fatal() *zerolog.Event { return logger.Fatal() }
func Panic() *zerolog.Event { return logger.Panic() }
func Trace() *zerolog.Event { return logger.Trace() }

View file

@ -0,0 +1,15 @@
package http
import "net/http"
type DummyResponseWriter struct{}
func (w DummyResponseWriter) Header() http.Header {
return make(http.Header)
}
func (w DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w DummyResponseWriter) WriteHeader(int) {}

View file

@ -1,15 +0,0 @@
package loadbalancer
import "net/http"
type DummyResponseWriter struct{}
func (w *DummyResponseWriter) Header() (_ http.Header) {
return
}
func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) {
return
}
func (w *DummyResponseWriter) WriteHeader(int) {}

View file

@ -11,20 +11,22 @@ import (
) )
type ipHash struct { type ipHash struct {
*LoadBalancer
realIP *middleware.Middleware realIP *middleware.Middleware
pool servers pool servers
mu sync.Mutex mu sync.Mutex
} }
func (lb *LoadBalancer) newIPHash() impl { func (lb *LoadBalancer) newIPHash() impl {
impl := new(ipHash) impl := &ipHash{LoadBalancer: lb}
if len(lb.Options) == 0 { if len(lb.Options) == 0 {
return impl return impl
} }
var err E.Error var err E.Error
impl.realIP, err = middleware.NewRealIP(lb.Options) impl.realIP, err = middleware.NewRealIP(lb.Options)
if err != nil { if err != nil {
logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) E.LogError("invalid real_ip options, ignoring", err, &impl.Logger)
} }
return impl return impl
} }
@ -70,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr) ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil { if err != nil {
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err) impl.Err(err).Msg("invalid remote address " + r.RemoteAddr)
return return
} }
idx := hashIP(ip) % uint32(len(impl.pool)) idx := hashIP(ip) % uint32(len(impl.pool))

View file

@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.R
srv := srvs[0] srv := srvs[0]
minConn, ok := impl.nConn.Load(srv) minConn, ok := impl.nConn.Load(srv)
if !ok { if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name) impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
} }
for i := 1; i < len(srvs); i++ { for i := 1; i < len(srvs); i++ {
nConn, ok := impl.nConn.Load(srvs[i]) nConn, ok := impl.nConn.Load(srvs[i])
if !ok { if !ok {
logger.Errorf("[BUG] server %s not found", srv.Name) impl.Error().Msgf("[BUG] server %s not found", srv.Name)
http.Error(rw, "Internal error", http.StatusInternalServerError) http.Error(rw, "Internal error", http.StatusInternalServerError)
} }
if nConn.Load() < minConn.Load() { if nConn.Load() < minConn.Load() {

View file

@ -6,9 +6,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
@ -29,6 +31,8 @@ type (
Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"` Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"`
} }
LoadBalancer struct { LoadBalancer struct {
zerolog.Logger
impl impl
*Config *Config
@ -48,6 +52,7 @@ const maxWeight weightType = 100
func New(cfg *Config) *LoadBalancer { func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{ lb := &LoadBalancer{
Logger: logger.With().Str("name", cfg.Link).Logger(),
Config: new(Config), Config: new(Config),
pool: newPool(), pool: newPool(),
task: task.DummyTask(), task: task.DummyTask(),
@ -102,7 +107,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) {
if lb.Mode == Unset && cfg.Mode != Unset { if lb.Mode == Unset && cfg.Mode != Unset {
lb.Mode = cfg.Mode lb.Mode = cfg.Mode
if !lb.Mode.ValidateUpdate() { if !lb.Mode.ValidateUpdate() {
logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode) lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode)
} }
lb.updateImpl() lb.updateImpl()
} }
@ -131,7 +136,11 @@ func (lb *LoadBalancer) AddServer(srv *Server) {
lb.rebalance() lb.rebalance()
lb.impl.OnAddServer(srv) lb.impl.OnAddServer(srv)
logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size())
lb.Debug().
Str("action", "add").
Str("server", srv.Name).
Msgf("%d servers available", lb.pool.Size())
} }
func (lb *LoadBalancer) RemoveServer(srv *Server) { func (lb *LoadBalancer) RemoveServer(srv *Server) {
@ -148,13 +157,15 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) {
lb.rebalance() lb.rebalance()
lb.impl.OnRemoveServer(srv) lb.impl.OnRemoveServer(srv)
lb.Debug().
Str("action", "remove").
Str("server", srv.Name).
Msgf("%d servers left", lb.pool.Size())
if lb.pool.Size() == 0 { if lb.pool.Size() == 0 {
lb.task.Finish("no server left") lb.task.Finish("no server left")
logger.Infof("loadbalancer %s stopped", lb.Link)
return return
} }
logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size())
} }
func (lb *LoadBalancer) rebalance() { func (lb *LoadBalancer) rebalance() {
@ -218,7 +229,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
defer cancel() defer cancel()
// send dummy request to wake all servers // send dummy request to wake all servers
var dummyRW *DummyResponseWriter var dummyRW gphttp.DummyResponseWriter
for _, srv := range srvs { for _, srv := range srvs {
// wake only if server implements Waker // wake only if server implements Waker
_, ok := srv.handler.(idlewatcher.Waker) _, ok := srv.handler.(idlewatcher.Waker)

View file

@ -1,5 +1,5 @@
package loadbalancer package loadbalancer
import "github.com/sirupsen/logrus" import "github.com/yusing/go-proxy/internal/logging"
var logger = logrus.WithField("module", "load_balancer") var logger = logging.With().Str("module", "load_balancer").Logger()

View file

@ -0,0 +1,5 @@
package http
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "http").Logger()

View file

@ -15,9 +15,9 @@ type cidrWhitelist struct {
} }
type cidrWhitelistOpts struct { type cidrWhitelistOpts struct {
Allow []*types.CIDR Allow []*types.CIDR `json:"allow"`
StatusCode int StatusCode int `json:"statusCode"`
Message string Message string `json:"message"`
cachedAddr F.Map[string, bool] // cache for trusted IPs cachedAddr F.Map[string, bool] // cache for trusted IPs
} }
@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) {
return nil, err return nil, err
} }
if len(wl.cidrWhitelistOpts.Allow) == 0 { if len(wl.cidrWhitelistOpts.Allow) == 0 {
return nil, E.Missing("allow range") return nil, E.New("no allowed CIDRs")
} }
return wl.m, nil return wl.m, nil
} }

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte
var deny, accept *Middleware var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) { func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose) errs := E.NewBuilder("")
if err != nil { mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs)
panic(err) ExpectNoError(t, errs.Error())
}
deny = mids["deny@file"] deny = mids["deny@file"]
accept = mids["accept@file"] accept = mids["accept@file"]
if deny == nil || accept == nil { if deny == nil || accept == nil {
@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("deny", func(t *testing.T) { t.Run("deny", func(t *testing.T) {
for range 10 { for range 10 {
result, err := newMiddlewareTest(deny, nil) result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode) ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message) ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
} }
@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) {
t.Run("accept", func(t *testing.T) { t.Run("accept", func(t *testing.T) {
for range 10 { for range 10 {
result, err := newMiddlewareTest(accept, nil) result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)
} }
}) })

View file

@ -10,10 +10,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
const ( const (
@ -26,7 +26,7 @@ const (
var ( var (
cfCIDRsLastUpdate time.Time cfCIDRsLastUpdate time.Time
cfCIDRsMu sync.Mutex cfCIDRsMu sync.Mutex
cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP") cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger()
) )
var CloudflareRealIP = &realIP{ var CloudflareRealIP = &realIP{
@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
) )
if err != nil { if err != nil {
cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)
cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval) cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil return nil
} }
} }
cfCIDRsLastUpdate = time.Now() cfCIDRsLastUpdate = time.Now()
cfCIDRsLogger.Debugf("cloudflare CIDR range updated") cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated")
return return
} }

View file

@ -8,24 +8,31 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
) )
var CustomErrorPage = &Middleware{ var CustomErrorPage *Middleware
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) { func init() {
next(w, r) CustomErrorPage = customErrorPage()
} }
},
modifyResponse: func(resp *Response) error { func customErrorPage() *Middleware {
m := &Middleware{
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next(w, r)
}
},
}
m.modifyResponse = func(resp *Response) error {
// only handles non-success status code and html/plain content type // only handles non-success status code and html/plain content type
contentType := gphttp.GetContentType(resp.Header) contentType := gphttp.GetContentType(resp.Header)
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok { if ok {
errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode) CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
/* trunk-ignore(golangci-lint/errcheck) */ /* trunk-ignore(golangci-lint/errcheck) */
io.Copy(io.Discard, resp.Body) // drain the original body io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close() resp.Body.Close()
@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{
resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage)))
resp.Header.Set("Content-Type", "text/html; charset=utf-8") resp.Header.Set("Content-Type", "text/html; charset=utf-8")
} else { } else {
errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode) CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
} }
return nil return nil
} }
return nil return nil
}, }
return m
} }
func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
filename := path[len(gphttp.StaticFilePathPrefix):] filename := path[len(gphttp.StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename) file, ok := errorpage.GetStaticFile(filename)
if !ok { if !ok {
errPageLogger.Errorf("unable to load resource %s", filename) logger.Error().Msg("unable to load resource " + filename)
return false return false
} }
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool {
case ".css": case ".css":
w.Header().Set("Content-Type", "text/css; charset=utf-8") w.Header().Set("Content-Type", "text/css; charset=utf-8")
default: default:
errPageLogger.Errorf("unexpected file type %q for %s", ext, filename) logger.Error().Msgf("unexpected file type %q for %s", ext, filename)
} }
if _, err := w.Write(file); err != nil { if _, err := w.Write(file); err != nil {
errPageLogger.WithError(err).Errorf("unable to write resource %s", filename) logger.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError) http.Error(w, "Error page failure", http.StatusInternalServerError)
} }
return true return true
} }
return false return false
} }
var errPageLogger = logrus.WithField("middleware", "error_page")

View file

@ -1,14 +1,15 @@
package errorpage package errorpage
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path" "path"
"sync" "sync"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
W "github.com/yusing/go-proxy/internal/watcher" W "github.com/yusing/go-proxy/internal/watcher"
@ -23,9 +24,10 @@ var (
) )
var setup = sync.OnceFunc(func() { var setup = sync.OnceFunc(func() {
dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath) task := task.GlobalTask("error page")
dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath)
loadContent() loadContent()
go watchDir() go watchDir(task)
}) })
func GetStaticFile(filename string) ([]byte, bool) { func GetStaticFile(filename string) ([]byte, bool) {
@ -44,7 +46,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() { func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0) files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil { if err != nil {
Logger.Error(err) logger.Err(err).Msg("failed to list error page resources")
return return
} }
for _, file := range files { for _, file := range files {
@ -53,19 +55,21 @@ func loadContent() {
} }
content, err := os.ReadFile(file) content, err := os.ReadFile(file)
if err != nil { if err != nil {
Logger.Errorf("failed to read error page resource %s: %s", file, err) logger.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue continue
} }
file = path.Base(file) file = path.Base(file)
Logger.Infof("error page resource %s loaded", file) logging.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content) fileContentMap.Store(file, content)
} }
} }
func watchDir() { func watchDir(task task.Task) {
eventCh, errCh := dirWatcher.Events(context.Background()) eventCh, errCh := dirWatcher.Events(task.Context())
for { for {
select { select {
case <-task.Context().Done():
return
case event, ok := <-eventCh: case event, ok := <-eventCh:
if !ok { if !ok {
return return
@ -77,14 +81,14 @@ func watchDir() {
loadContent() loadContent()
case events.ActionFileDeleted: case events.ActionFileDeleted:
fileContentMap.Delete(filename) fileContentMap.Delete(filename)
Logger.Infof("error page resource %s deleted", filename) logger.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed: case events.ActionFileRenamed:
Logger.Infof("error page resource %s deleted", filename) logger.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename) fileContentMap.Delete(filename)
loadContent() loadContent()
} }
case err := <-errCh: case err := <-errCh:
Logger.Errorf("error watching error page directory: %s", err) E.LogError("error watching error page directory", err, &logger)
} }
} }
} }

View file

@ -0,0 +1,5 @@
package errorpage
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "errorpage").Logger()

View file

@ -24,11 +24,12 @@ type (
client http.Client client http.Client
} }
forwardAuthOpts struct { forwardAuthOpts struct {
Address string Address string `json:"address"`
TrustForwardHeader bool TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
transport http.RoundTripper
} }
) )
@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth) fa := new(forwardAuth)
fa.forwardAuthOpts = new(forwardAuthOpts) fa.forwardAuthOpts = new(forwardAuthOpts)
err := Deserialize(optsRaw, fa.forwardAuthOpts) if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil {
if err != nil {
return nil, err return nil, err
} }
_, err = E.Check(url.Parse(fa.Address)) if _, err := url.Parse(fa.Address); err != nil {
if err != nil { return nil, E.From(err)
return nil, E.Invalid("address", fa.Address)
} }
fa.m = &Middleware{ fa.m = &Middleware{

View file

@ -0,0 +1,5 @@
package middleware
import "github.com/yusing/go-proxy/internal/logging"
var logger = logging.With().Str("module", "middleware").Logger()

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
@ -32,6 +33,8 @@ type (
Middleware struct { Middleware struct {
_ U.NoCopy _ U.NoCopy
zerolog.Logger
name string name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP before BeforeFunc // runs before ReverseProxy.ServeHTTP
@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
} }
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if len(optsRaw) != 0 && m.withOptions != nil { if m.withOptions != nil {
return m.withOptions(optsRaw) m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
} }
// WithOptionsClone is called only once // WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that // set withOptions and labelParser will not be used after that
return &Middleware{ return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name, name: m.name,
before: m.before, before: m.before,
modifyResponse: m.modifyResponse, modifyResponse: m.modifyResponse,
@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
} }
// TODO: check conflict or duplicates. // TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) { func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares = make([]*Middleware, 0, len(middlewaresMap)) middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares") errs := E.NewBuilder("middlewares compile error")
invalidOpts := E.NewBuilder("invalid options") invalidOpts := E.NewBuilder("options compile error")
defer func() {
invalidM.Add(invalidOpts.Build())
invalidM.To(&res)
}()
for name, opts := range middlewaresMap { for name, opts := range middlewaresMap {
m, ok := Get(name) m, err := Get(name)
if !ok { if err != nil {
invalidM.Add(E.NotExist("middleware", name)) errs.Add(err)
continue continue
} }
m, err := m.WithOptionsClone(opts) m, err = m.WithOptionsClone(opts)
if err != nil { if err != nil {
invalidOpts.Add(err.Subject(name)) invalidOpts.Add(err.Subject(name))
continue continue
@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid
middlewares = append(middlewares, m) middlewares = append(middlewares, m)
} }
return if invalidOpts.HasError() {
errs.Add(invalidOpts.Error())
}
return middlewares, errs.Error()
} }
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {

View file

@ -4,64 +4,60 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"path"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) { func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware {
fileContent, err := os.ReadFile(filePath) fileContent, err := os.ReadFile(filePath)
if err != nil { if err != nil {
return nil, E.FailWith("read middleware compose file", err) eb.Add(err)
return nil
} }
return BuildMiddlewaresFromYAML(fileContent) return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
} }
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) { func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
var rawMap map[string][]map[string]any var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap) err := yaml.Unmarshal(data, &rawMap)
if err != nil { if err != nil {
b.Add(E.FailWith("yaml unmarshal", err)) eb.Add(err)
return return nil
} }
middlewares = make(map[string]*Middleware) middlewares := make(map[string]*Middleware)
for name, defs := range rawMap { for name, defs := range rawMap {
chainErr := E.NewBuilder("%s", name) chainErr := E.NewBuilder("")
chain := make([]*Middleware, 0, len(defs)) chain := make([]*Middleware, 0, len(defs))
for i, def := range defs { for i, def := range defs {
if def["use"] == nil || def["use"] == "" { if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i)) chainErr.Addf("item %d: missing field 'use'", i)
continue continue
} }
baseName := def["use"].(string) baseName := def["use"].(string)
base, ok := Get(baseName) base, err := Get(baseName)
if !ok { if err != nil {
base, ok = middlewares[baseName] chainErr.Add(err.Subjectf("%s[%d]", name, i))
if !ok { continue
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
} }
delete(def, "use") delete(def, "use")
m, err := base.WithOptionsClone(def) m, err := base.WithOptionsClone(def)
if err != nil { if err != nil {
chainErr.Add(err.Subjectf("item%d", i)) chainErr.Add(err.Subjectf("%s[%d]", name, i))
continue continue
} }
m.name = fmt.Sprintf("%s[%d]", name, i) m.name = fmt.Sprintf("%s[%d]", name, i)
chain = append(chain, m) chain = append(chain, m)
} }
if chainErr.HasError() { if chainErr.HasError() {
b.Add(chainErr.Build()) eb.Add(chainErr.Error().Subject(source))
} else { } else {
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain) middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
} }
} }
return return middlewares
} }
// TODO: check conflict or duplicates. // TODO: check conflict or duplicates.
@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
} }
if len(modResps) > 0 { if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error { m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware") errs := E.NewBuilder("modify response errors")
for _, mr := range modResps { for _, mr := range modResps {
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name)) if err := mr.modifyResponse(res); err != nil {
errs.Add(E.From(err).Subject(mr.name))
}
} }
return b.Build().Error() return errs.Error()
} }
} }

View file

@ -13,10 +13,10 @@ import (
var testMiddlewareCompose []byte var testMiddlewareCompose []byte
func TestBuild(t *testing.T) { func TestBuild(t *testing.T) {
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) errs := E.NewBuilder("")
ExpectNoError(t, err.Error()) middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs)
_, err = E.Check(json.MarshalIndent(middlewares, "", " ")) ExpectNoError(t, errs.Error())
ExpectNoError(t, err.Error()) E.Must(json.MarshalIndent(middlewares, "", " "))
// t.Log(string(data)) // t.Log(string(data))
// TODO: test // TODO: test
} }

View file

@ -6,26 +6,37 @@ import (
"path" "path"
"strings" "strings"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
var middlewares map[string]*Middleware var allMiddlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) { var (
middleware, ok = middlewares[U.ToLowerNoSnake(name)] ErrUnknownMiddleware = E.New("unknown middleware")
return ErrDuplicatedMiddleware = E.New("duplicated middleware")
)
func Get(name string) (*Middleware, Error) {
middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)]
if !ok {
return nil, ErrUnknownMiddleware.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares)))
}
return middleware, nil
} }
func All() map[string]*Middleware { func All() map[string]*Middleware {
return middlewares return allMiddlewares
} }
// initialize middleware names and label parsers // initialize middleware names and label parsers
func init() { func init() {
middlewares = map[string]*Middleware{ allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded, "setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded, "hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP, "redirecthttp": RedirectHTTP,
@ -39,10 +50,10 @@ func init() {
// !experimental // !experimental
"forwardauth": ForwardAuth.m, "forwardauth": ForwardAuth.m,
"oauth2": OAuth2.m, // "oauth2": OAuth2.m,
} }
names := make(map[*Middleware][]string) names := make(map[*Middleware][]string)
for name, m := range middlewares { for name, m := range allMiddlewares {
names[m] = append(names[m], http.CanonicalHeaderKey(name)) names[m] = append(names[m], http.CanonicalHeaderKey(name))
} }
for m, names := range names { for m, names := range names {
@ -55,27 +66,30 @@ func init() {
} }
func LoadComposeFiles() { func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares") errs := E.NewBuilder("middleware compile errors")
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0) middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil { if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err) logger.Err(err).Msg("failed to list middleware definitions")
return return
} }
for _, defFile := range middlewareDefs { for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromComposeFile(defFile) mws := BuildMiddlewaresFromComposeFile(defFile, errs)
if len(mws) == 0 {
continue
}
for name, m := range mws { for name, m := range mws {
if _, ok := middlewares[name]; ok { if _, ok := allMiddlewares[name]; ok {
b.Add(E.Duplicated("middleware", name)) errs.Add(ErrDuplicatedMiddleware.Subject(name))
continue continue
} }
middlewares[U.ToLowerNoSnake(name)] = m allMiddlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) logger.Info().
Str("name", name).
Str("src", path.Base(defFile)).
Msg("middleware loaded")
} }
b.Add(err.Subject(path.Base(defFile)))
} }
if b.HasError() { if errs.HasError() {
logger.Error(b.Build()) E.LogError(errs.About(), errs.Error(), &logger)
} }
} }
var logger = logrus.WithField("module", "middlewares")

View file

@ -12,9 +12,9 @@ type (
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct { modifyRequestOpts struct {
SetHeaders map[string]string SetHeaders map[string]string `json:"setHeaders"`
AddHeaders map[string]string AddHeaders map[string]string `json:"addHeaders"`
HideHeaders []string HideHeaders []string `json:"hideHeaders"`
} }
) )

View file

@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) {
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts) mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string))
@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) {
result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
middlewareOpt: opts, middlewareOpt: opts,
}) })
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")

View file

@ -13,11 +13,7 @@ type (
m *Middleware m *Middleware
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
modifyResponseOpts struct { modifyResponseOpts = modifyRequestOpts
SetHeaders map[string]string
AddHeaders map[string]string
HideHeaders []string
}
) )
var ModifyResponse = &modifyResponse{ var ModifyResponse = &modifyResponse{

View file

@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) {
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts) mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) {
result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{ result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{
middlewareOpt: opts, middlewareOpt: opts,
}) })
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding")) t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))

View file

@ -1,129 +1,117 @@
package middleware package middleware
import ( // import (
"encoding/json" // "encoding/json"
"fmt" // "fmt"
"net/http" // "net/http"
"net/url" // "net/url"
"reflect"
E "github.com/yusing/go-proxy/internal/error" // E "github.com/yusing/go-proxy/internal/error"
) // )
type oAuth2 struct { // type oAuth2 struct {
*oAuth2Opts // oAuth2Opts
m *Middleware // m *Middleware
} // }
type oAuth2Opts struct { // type oAuth2Opts struct {
ClientID string // ClientID string `validate:"required"`
ClientSecret string // ClientSecret string `validate:"required"`
AuthURL string // Authorization Endpoint // AuthURL string `validate:"required"` // Authorization Endpoint
TokenURL string // Token Endpoint // TokenURL string `validate:"required"` // Token Endpoint
} // }
var OAuth2 = &oAuth2{ // var OAuth2 = &oAuth2{
m: &Middleware{withOptions: NewAuthentikOAuth2}, // m: &Middleware{withOptions: NewAuthentikOAuth2},
} // }
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { // func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) {
oauth := new(oAuth2) // oauth := new(oAuth2)
oauth.m = &Middleware{ // oauth.m = &Middleware{
impl: oauth, // impl: oauth,
before: oauth.handleOAuth2, // before: oauth.handleOAuth2,
} // }
oauth.oAuth2Opts = &oAuth2Opts{} // err := Deserialize(opts, &oauth.oAuth2Opts)
err := Deserialize(opts, oauth.oAuth2Opts) // if err != nil {
if err != nil { // return nil, err
return nil, err // }
} // return oauth.m, nil
b := E.NewBuilder("missing required fields") // }
optV := reflect.ValueOf(oauth.oAuth2Opts)
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
if optV.FieldByName(field.Name).Len() == 0 {
b.Add(E.Missing(field.Name))
}
}
if b.HasError() {
return nil, b.Build().Subject("oAuth2")
}
return oauth.m, nil
}
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { // func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
// Check if the user is authenticated (you may use session, cookie, etc.) // // Check if the user is authenticated (you may use session, cookie, etc.)
if !userIsAuthenticated(r) { // if !userIsAuthenticated(r) {
// TODO: Redirect to OAuth2 auth URL // // TODO: Redirect to OAuth2 auth URL
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", // http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) // oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
return // return
} // }
// If you have a token in the query string, process it // // If you have a token in the query string, process it
if code := r.URL.Query().Get("code"); code != "" { // if code := r.URL.Query().Get("code"); code != "" {
// Exchange the authorization code for a token here // // Exchange the authorization code for a token here
// Use the TokenURL and authenticate the user // // Use the TokenURL and authenticate the user
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI) // token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI)
if err != nil { // if err != nil {
// handle error // // handle error
http.Error(rw, "failed to get token", http.StatusUnauthorized) // http.Error(rw, "failed to get token", http.StatusUnauthorized)
return // return
} // }
// Save token and user info based on your requirements // // Save token and user info based on your requirements
saveToken(rw, token) // saveToken(rw, token)
// Redirect to the originally requested URL // // Redirect to the originally requested URL
http.Redirect(rw, r, "/", http.StatusFound) // http.Redirect(rw, r, "/", http.StatusFound)
return // return
} // }
// If user is authenticated, go to the next handler // // If user is authenticated, go to the next handler
next(rw, r) // next(rw, r)
} // }
func userIsAuthenticated(r *http.Request) bool { // func userIsAuthenticated(r *http.Request) bool {
// Example: Check for a session or cookie // // Example: Check for a session or cookie
session, err := r.Cookie("session_token") // session, err := r.Cookie("session_token")
if err != nil || session.Value == "" { // if err != nil || session.Value == "" {
return false // return false
} // }
// Validate the session_token if necessary // // Validate the session_token if necessary
return true // return true
} // }
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { // func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) {
// Prepare the request body // // Prepare the request body
data := url.Values{ // data := url.Values{
"client_id": {opts.ClientID}, // "client_id": {opts.ClientID},
"client_secret": {opts.ClientSecret}, // "client_secret": {opts.ClientSecret},
"code": {code}, // "code": {code},
"grant_type": {"authorization_code"}, // "grant_type": {"authorization_code"},
"redirect_uri": {requestURI}, // "redirect_uri": {requestURI},
} // }
resp, err := http.PostForm(opts.TokenURL, data) // resp, err := http.PostForm(opts.TokenURL, data)
if err != nil { // if err != nil {
return "", fmt.Errorf("failed to request token: %w", err) // return "", fmt.Errorf("failed to request token: %w", err)
} // }
defer resp.Body.Close() // defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { // if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) // return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
} // }
// Decode the response // // Decode the response
var tokenResp struct { // var tokenResp struct {
AccessToken string `json:"access_token"` // AccessToken string `json:"access_token"`
} // }
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { // if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err) // return "", fmt.Errorf("failed to decode token response: %w", err)
} // }
return tokenResp.AccessToken, nil // return tokenResp.AccessToken, nil
} // }
func saveToken(rw ResponseWriter, token string) { // func saveToken(rw ResponseWriter, token string) {
// Example: Save token in cookie // // Example: Save token in cookie
http.SetCookie(rw, &http.Cookie{ // http.SetCookie(rw, &http.Cookie{
Name: "auth_token", // Name: "auth_token",
Value: token, // Value: token,
// set other properties as necessary, such as Secure and HttpOnly // // set other properties as necessary, such as Secure and HttpOnly
}) // })
} // }

View file

@ -16,9 +16,9 @@ type realIP struct {
type realIPOpts struct { type realIPOpts struct {
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP
Header string Header string `json:"header"`
// From is a list of Address / CIDRs to trust // From is a list of Address / CIDRs to trust
From []*types.CIDR From []*types.CIDR `json:"from"`
/* /*
If recursive search is disabled, If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
@ -27,7 +27,7 @@ type realIPOpts struct {
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
the last non-trusted address sent in the request header field. the last non-trusted address sent in the request header field.
*/ */
Recursive bool Recursive bool `json:"recursive"`
} }
var RealIP = &realIP{ var RealIP = &realIP{

View file

@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) {
} }
ri, err := NewRealIP(opts) ri, err := NewRealIP(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From { for i, CIDR := range ri.impl.(*realIP).From {
@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) {
"set_headers": map[string]string{testHeader: testRealIP}, "set_headers": map[string]string{testHeader: testRealIP},
} }
realip, err := NewRealIP(opts) realip, err := NewRealIP(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
mr, err := NewModifyRequest(optsMr) mr, err := NewModifyRequest(optsMr)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil) result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
t.Log(traces) t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)

View file

@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http", scheme: "http",
}) })
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
} }
@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https", scheme: "https",
}) })
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)
} }

View file

@ -6,7 +6,7 @@ import (
"time" "time"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type Trace struct { type Trace struct {
@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
return nil return nil
} }
return addTrace(&Trace{ return addTrace(&Trace{
Time: U.FormatTime(time.Now()), Time: strutils.FormatTime(time.Now()),
Caller: m.Fullname(), Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...), Message: fmt.Sprintf(msg, args...),
}) })

View file

@ -57,8 +57,7 @@ func (w *ModifyResponseWriter) WriteHeader(code int) {
} }
if err := w.modifier(&resp); err != nil { if err := w.modifier(&resp); err != nil {
w.modifierErr = err w.modifierErr = fmt.Errorf("response modifier error: %w", err)
logger.Errorf("error modifying response: %s", err)
w.w.WriteHeader(http.StatusInternalServerError) w.w.WriteHeader(http.StatusInternalServerError)
return return
} }

View file

@ -23,7 +23,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
@ -69,6 +69,8 @@ type ProxyRequest struct {
// 1xx responses are forwarded to the client if the underlying // 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse. // transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct { type ReverseProxy struct {
zerolog.Logger
// The transport used to perform proxy requests. // The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used. // If nil, http.DefaultTransport is used.
Transport http.RoundTripper Transport http.RoundTripper
@ -149,7 +151,12 @@ func NewReverseProxy(name string, target types.URL, transport http.RoundTripper)
if transport == nil { if transport == nil {
panic("nil transport") panic("nil transport")
} }
rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target} rp := &ReverseProxy{
Logger: logger.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
}
rp.ServeHTTP = rp.serveHTTP rp.ServeHTTP = rp.serveHTTP
return rp return rp
} }
@ -195,9 +202,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
switch { switch {
case errors.Is(err, context.Canceled), case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF): errors.Is(err, io.EOF):
logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error")
default: default:
logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error")
} }
if writeHeader { if writeHeader {
rw.WriteHeader(http.StatusBadGateway) rw.WriteHeader(http.StatusBadGateway)
@ -219,6 +226,10 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
} }
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
if _, ok := rw.(DummyResponseWriter); ok {
return
}
transport := p.Transport transport := p.Transport
ctx := req.Context() ctx := req.Context()
@ -453,6 +464,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
resUpType := UpgradeType(res.Header) resUpType := UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return
} }
if !strings.EqualFold(reqUpType, resUpType) { if !strings.EqualFold(reqUpType, resUpType) {
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true) p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true)
@ -518,5 +530,3 @@ func IsPrint(s string) bool {
} }
return true return true
} }
var logger = logrus.WithField("module", "http")

View file

@ -9,10 +9,15 @@ import (
type CIDR net.IPNet type CIDR net.IPNet
var (
ErrInvalidCIDR = E.New("invalid CIDR")
ErrInvalidCIDRType = E.New("invalid CIDR type")
)
func (cidr *CIDR) ConvertFrom(val any) E.Error { func (cidr *CIDR) ConvertFrom(val any) E.Error {
cidrStr, ok := val.(string) cidrStr, ok := val.(string)
if !ok { if !ok {
return E.TypeMismatch[string](val) return ErrInvalidCIDRType.Subjectf("%T", val)
} }
if !strings.Contains(cidrStr, "/") { if !strings.Contains(cidrStr, "/") {
@ -20,7 +25,7 @@ func (cidr *CIDR) ConvertFrom(val any) E.Error {
} }
_, ipnet, err := net.ParseCIDR(cidrStr) _, ipnet, err := net.ParseCIDR(cidrStr)
if err != nil { if err != nil {
return E.Invalid("CIDR", cidr) return ErrInvalidCIDR.Subject(cidrStr)
} }
*cidr = CIDR(*ipnet) *cidr = CIDR(*ipnet)
return nil return nil

View file

@ -5,9 +5,28 @@ import (
"net" "net"
) )
type Stream interface { type (
fmt.Stringer Stream interface {
net.Listener fmt.Stringer
Setup() error StreamListener
Handle(conn net.Conn) error Setup() error
Handle(conn StreamConn) error
}
StreamListener interface {
Addr() net.Addr
Accept() (StreamConn, error)
Close() error
}
StreamConn any
NetListenerWrapper struct {
net.Listener
}
)
func NetListener(l net.Listener) StreamListener {
return NetListenerWrapper{Listener: l}
}
func (l NetListenerWrapper) Accept() (StreamConn, error) {
return l.Listener.Accept()
} }

View file

@ -1,22 +1,32 @@
package notif package notif
import ( import (
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( type (
Dispatcher struct { Dispatcher struct {
task task.Task task task.Task
logCh chan *logrus.Entry logCh chan *LogMessage
providers F.Set[Provider] providers F.Set[Provider]
} }
LogMessage struct {
Level zerolog.Level
Title, Message string
}
) )
var dispatcher *Dispatcher var dispatcher *Dispatcher
var ErrUnknownNotifProvider = E.New("unknown notification provider")
const dispatchErr = "notification dispatch error"
func init() { func init() {
dispatcher = newNotifDispatcher() dispatcher = newNotifDispatcher()
go dispatcher.start() go dispatcher.start()
@ -25,7 +35,7 @@ func init() {
func newNotifDispatcher() *Dispatcher { func newNotifDispatcher() *Dispatcher {
return &Dispatcher{ return &Dispatcher{
task: task.GlobalTask("notif dispatcher"), task: task.GlobalTask("notif dispatcher"),
logCh: make(chan *logrus.Entry), logCh: make(chan *LogMessage),
providers: F.NewSet[Provider](), providers: F.NewSet[Provider](),
} }
} }
@ -34,11 +44,13 @@ func GetDispatcher() *Dispatcher {
return dispatcher return dispatcher
} }
func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.Error) { func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) {
name := configSubTask.Name() name := configSubTask.Name()
createFunc, ok := Providers[name] createFunc, ok := Providers[name]
if !ok { if !ok {
return nil, E.NotExist("provider", name) return nil, ErrUnknownNotifProvider.
Subject(name).
Withf(strutils.DoYouMean(utils.NearestField(name, Providers)))
} }
if provider, err := createFunc(cfg); err != nil { if provider, err := createFunc(cfg); err != nil {
return nil, err return nil, err
@ -53,7 +65,6 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.
func (disp *Dispatcher) start() { func (disp *Dispatcher) start() {
defer dispatcher.task.Finish("dispatcher stopped") defer dispatcher.task.Finish("dispatcher stopped")
defer close(dispatcher.logCh)
for { for {
select { select {
@ -65,36 +76,39 @@ func (disp *Dispatcher) start() {
} }
} }
func (disp *Dispatcher) dispatch(entry *logrus.Entry) { func (disp *Dispatcher) dispatch(msg *LogMessage) {
task := disp.task.Subtask("dispatch notif") task := disp.task.Subtask("dispatch notif")
defer task.Finish("notifs dispatched") defer task.Finish("notif dispatched")
errs := E.NewBuilder("errors sending notif") errs := E.NewBuilder(dispatchErr)
disp.providers.RangeAllParallel(func(p Provider) { disp.providers.RangeAllParallel(func(p Provider) {
if err := p.Send(task.Context(), entry); err != nil { if err := p.Send(task.Context(), msg); err != nil {
errs.Addf("%s: %s", p.Name(), err) errs.Add(E.PrependSubject(p.Name(), err))
} }
}) })
if err := errs.Build(); err != nil { if errs.HasError() {
logrus.Error("notif dispatcher failure: ", err) E.LogError(errs.About(), errs.Error())
} }
} }
// Levels implements logrus.Hook. // Run implements zerolog.Hook.
func (disp *Dispatcher) Levels() []logrus.Level { // func (disp *Dispatcher) Run(e *zerolog.Event, level zerolog.Level, message string) {
return []logrus.Level{ // if strings.HasPrefix(message, dispatchErr) { // prevent recursion
logrus.WarnLevel, // return
logrus.ErrorLevel, // }
logrus.FatalLevel, // switch level {
logrus.PanicLevel, // case zerolog.WarnLevel, zerolog.ErrorLevel, zerolog.FatalLevel, zerolog.PanicLevel:
} // disp.logCh <- &LogMessage{
} // Level: level,
// Message: message,
// }
// }
// }
// Fire implements logrus.Hook. func Notify(title, msg string) {
func (disp *Dispatcher) Fire(entry *logrus.Entry) error { dispatcher.logCh <- &LogMessage{
if disp.providers.Size() == 0 { Level: zerolog.InfoLevel,
return nil Title: title,
Message: msg,
} }
disp.logCh <- entry
return nil
} }

View file

@ -9,7 +9,7 @@ import (
"net/url" "net/url"
"github.com/gotify/server/v2/model" "github.com/gotify/server/v2/model"
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
) )
@ -39,7 +39,7 @@ func newGotifyClient(cfg map[string]any) (Provider, E.Error) {
url, uErr := url.Parse(client.URL) url, uErr := url.Parse(client.URL)
if uErr != nil { if uErr != nil {
return nil, E.FailWith("parse url", uErr) return nil, E.Errorf("invalid gotify URL %s", client.URL)
} }
client.url = url client.url = url
@ -52,30 +52,23 @@ func (client *GotifyClient) Name() string {
} }
// Send implements NotifProvider. // Send implements NotifProvider.
func (client *GotifyClient) Send(ctx context.Context, entry *logrus.Entry) error { func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error {
var priority int var priority int
var title string
switch entry.Level { switch logMsg.Level {
case logrus.WarnLevel: case zerolog.WarnLevel:
priority = 2 priority = 2
title = "Warning" case zerolog.ErrorLevel:
case logrus.ErrorLevel:
priority = 5 priority = 5
title = "Error" case zerolog.FatalLevel, zerolog.PanicLevel:
case logrus.FatalLevel, logrus.PanicLevel:
priority = 8 priority = 8
title = "Critical"
default: default:
return nil return nil
} }
if subjects := FieldsAsTitle(entry); subjects != "" {
title = subjects + " " + title
}
msg := &GotifyMessage{ msg := &GotifyMessage{
Title: title, Title: logMsg.Title,
Message: entry.Message, Message: logMsg.Message,
Priority: priority, Priority: priority,
} }

View file

@ -1,21 +0,0 @@
package notif
import (
"fmt"
"strings"
"github.com/sirupsen/logrus"
U "github.com/yusing/go-proxy/internal/utils"
)
func FieldsAsTitle(entry *logrus.Entry) string {
if len(entry.Data) == 0 {
return ""
}
var parts []string
for k, v := range entry.Data {
parts = append(parts, fmt.Sprintf("%s: %s", k, v))
}
parts[0] = U.Title(parts[0])
return strings.Join(parts, ", ")
}

View file

@ -3,14 +3,13 @@ package notif
import ( import (
"context" "context"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
type ( type (
Provider interface { Provider interface {
Name() string Name() string
Send(ctx context.Context, entry *logrus.Entry) error Send(ctx context.Context, logMsg *LogMessage) error
} }
ProviderCreateFunc func(map[string]any) (Provider, E.Error) ProviderCreateFunc func(map[string]any) (Provider, E.Error)
ProviderConfig map[string]any ProviderConfig map[string]any

View file

@ -23,18 +23,18 @@ func ValidateEntry(m *RawEntry) (Entry, E.Error) {
scheme, err := T.NewScheme(m.Scheme) scheme, err := T.NewScheme(m.Scheme)
if err != nil { if err != nil {
return nil, err return nil, E.From(err)
} }
var entry Entry var entry Entry
e := E.NewBuilder("error validating entry") errs := E.NewBuilder("entry validation failed")
if scheme.IsStream() { if scheme.IsStream() {
entry = validateStreamEntry(m, e) entry = validateStreamEntry(m, errs)
} else { } else {
entry = validateRPEntry(m, scheme, e) entry = validateRPEntry(m, scheme, errs)
} }
if err := e.Build(); err != nil { if errs.HasError() {
return nil, err return nil, errs.Error()
} }
return entry, nil return entry, nil
} }

View file

@ -5,13 +5,14 @@ import (
"strings" "strings"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher/health" "github.com/yusing/go-proxy/internal/watcher/health"
) )
@ -85,20 +86,20 @@ func (e *RawEntry) FillMissingFields() {
} else if !isDocker { } else if !isDocker {
pp = "80" pp = "80"
} else { } else {
logrus.Debugf("no port found for %s", e.Alias) logging.Debug().Msg("no port found for " + e.Alias)
} }
} }
// replace private port with public port if using public IP. // replace private port with public port if using public IP.
if e.Host == cont.PublicIP { if e.Host == cont.PublicIP {
if p, ok := cont.PrivatePortMapping[pp]; ok { if p, ok := cont.PrivatePortMapping[pp]; ok {
pp = U.PortString(p.PublicPort) pp = strutils.PortString(p.PublicPort)
} }
} }
// replace public port with private port if using private IP. // replace public port with private port if using private IP.
if e.Host == cont.PrivateIP { if e.Host == cont.PrivateIP {
if p, ok := cont.PublicPortMapping[pp]; ok { if p, ok := cont.PublicPortMapping[pp]; ok {
pp = U.PortString(p.PrivatePort) pp = strutils.PortString(p.PrivatePort)
} }
} }

View file

@ -53,7 +53,7 @@ func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config {
return rp.Idlewatcher return rp.Idlewatcher
} }
func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry { func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProxyEntry {
cont := m.Container cont := m.Container
if cont == nil { if cont == nil {
cont = docker.DummyContainer cont = docker.DummyContainer
@ -64,35 +64,26 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEnt
lb = nil lb = nil
} }
host, err := fields.ValidateHost(m.Host) host := E.Collect(errs, fields.ValidateHost, m.Host)
b.Add(err) port := E.Collect(errs, fields.ValidatePort, m.Port)
pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns)
url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port))
iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
port, err := fields.ValidatePort(m.Port) if errs.HasError() {
b.Add(err)
pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns)
b.Add(err)
url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if err != nil {
return nil return nil
} }
return &ReverseProxyEntry{ return &ReverseProxyEntry{
Raw: m, Raw: m,
Alias: fields.NewAlias(m.Alias), Alias: fields.Alias(m.Alias),
Scheme: s, Scheme: s,
URL: net.NewURL(url), URL: net.NewURL(url),
NoTLSVerify: m.NoTLSVerify, NoTLSVerify: m.NoTLSVerify,
PathPatterns: pathPatterns, PathPatterns: pathPats,
HealthCheck: m.HealthCheck, HealthCheck: m.HealthCheck,
LoadBalance: lb, LoadBalance: lb,
Middlewares: m.Middlewares, Middlewares: m.Middlewares,
Idlewatcher: idleWatcherCfg, Idlewatcher: iwCfg,
} }
} }

View file

@ -51,34 +51,25 @@ func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config {
return s.Idlewatcher return s.Idlewatcher
} }
func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry { func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry {
cont := m.Container cont := m.Container
if cont == nil { if cont == nil {
cont = docker.DummyContainer cont = docker.DummyContainer
} }
host, err := fields.ValidateHost(m.Host) host := E.Collect(errs, fields.ValidateHost, m.Host)
b.Add(err) port := E.Collect(errs, fields.ValidateStreamPort, m.Port)
scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme)
url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort))
idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container)
port, err := fields.ValidateStreamPort(m.Port) if errs.HasError() {
b.Add(err)
scheme, err := fields.ValidateStreamScheme(m.Scheme)
b.Add(err)
url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort)))
b.Add(err)
idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container)
b.Add(err)
if b.HasError() {
return nil return nil
} }
return &StreamEntry{ return &StreamEntry{
Raw: m, Raw: m,
Alias: fields.NewAlias(m.Alias), Alias: fields.Alias(m.Alias),
Scheme: *scheme, Scheme: *scheme,
URL: url, URL: url,
Host: host, Host: host,

View file

@ -1,6 +1,3 @@
package fields package fields
type ( type Alias string
Alias string
NewAlias = Alias
)

View file

@ -1,14 +1,10 @@
package fields package fields
import (
E "github.com/yusing/go-proxy/internal/error"
)
type ( type (
Host string Host string
Subdomain = Alias Subdomain = Alias
) )
func ValidateHost[String ~string](s String) (Host, E.Error) { func ValidateHost[String ~string](s String) (Host, error) {
return Host(s), nil return Host(s), nil
} }

View file

@ -1,6 +1,8 @@
package fields package fields
import ( import (
"errors"
"fmt"
"regexp" "regexp"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -13,12 +15,17 @@ type (
var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`)
func ValidatePathPattern(s string) (PathPattern, E.Error) { var (
ErrEmptyPathPattern = errors.New("path must not be empty")
ErrInvalidPathPattern = errors.New("invalid path pattern")
)
func ValidatePathPattern(s string) (PathPattern, error) {
if len(s) == 0 { if len(s) == 0 {
return "", E.Invalid("path", "must not be empty") return "", ErrEmptyPathPattern
} }
if !pathPattern.MatchString(s) { if !pathPattern.MatchString(s) {
return "", E.Invalid("path pattern", s) return "", fmt.Errorf("%w %q", ErrInvalidPathPattern, s)
} }
return PathPattern(s), nil return PathPattern(s), nil
} }
@ -27,13 +34,15 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.Error) {
if len(s) == 0 { if len(s) == 0 {
return []PathPattern{"/"}, nil return []PathPattern{"/"}, nil
} }
errs := E.NewBuilder("invalid path patterns")
pp := make(PathPatterns, len(s)) pp := make(PathPatterns, len(s))
for i, v := range s { for i, v := range s {
pattern, err := ValidatePathPattern(v) pattern, err := ValidatePathPattern(v)
if err != nil { if err != nil {
return nil, err errs.Add(err)
} else {
pp[i] = pattern
} }
pp[i] = pattern
} }
return pp, nil return pp, errs.Error()
} }

View file

@ -1,9 +1,9 @@
package fields package fields
import ( import (
"errors"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils/testing" U "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -38,10 +38,10 @@ var invalidPatterns = []string{
func TestPathPatternRegex(t *testing.T) { func TestPathPatternRegex(t *testing.T) {
for _, pattern := range validPatterns { for _, pattern := range validPatterns {
_, err := ValidatePathPattern(pattern) _, err := ValidatePathPattern(pattern)
U.ExpectNoError(t, err.Error()) U.ExpectNoError(t, err)
} }
for _, pattern := range invalidPatterns { for _, pattern := range invalidPatterns {
_, err := ValidatePathPattern(pattern) _, err := ValidatePathPattern(pattern)
U.ExpectError2(t, pattern, E.ErrInvalid, err.Error()) U.ExpectTrue(t, errors.Is(err, ErrInvalidPathPattern))
} }
} }

View file

@ -4,22 +4,25 @@ import (
"strconv" "strconv"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type Port int type Port int
func ValidatePort[String ~string](v String) (Port, E.Error) { var ErrPortOutOfRange = E.New("port out of range")
p, err := strconv.Atoi(string(v))
func ValidatePort[String ~string](v String) (Port, error) {
p, err := strutils.Atoi(string(v))
if err != nil { if err != nil {
return ErrPort, E.Invalid("port number", v).With(err) return ErrPort, err
} }
return ValidatePortInt(p) return ValidatePortInt(p)
} }
func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) { func ValidatePortInt[Int int | uint16](v Int) (Port, error) {
p := Port(v) p := Port(v)
if !p.inBound() { if !p.inBound() {
return ErrPort, E.OutOfRange("port", p) return ErrPort, ErrPortOutOfRange.Subject(strconv.Itoa(int(p)))
} }
return p, nil return p, nil
} }

View file

@ -6,12 +6,14 @@ import (
type Scheme string type Scheme string
func NewScheme[String ~string](s String) (Scheme, E.Error) { var ErrInvalidScheme = E.New("invalid scheme")
func NewScheme(s string) (Scheme, error) {
switch s { switch s {
case "http", "https", "tcp", "udp": case "http", "https", "tcp", "udp":
return Scheme(s), nil return Scheme(s), nil
} }
return "", E.Invalid("scheme", s) return "", ErrInvalidScheme.Subject(s)
} }
func (s Scheme) IsHTTP() bool { return s == "http" } func (s Scheme) IsHTTP() bool { return s == "http" }

View file

@ -3,7 +3,6 @@ package fields
import ( import (
"strings" "strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
@ -12,7 +11,9 @@ type StreamPort struct {
ProxyPort Port `json:"proxy"` ProxyPort Port `json:"proxy"`
} }
func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { var ErrStreamPortTooManyColons = E.New("too many colons")
func ValidateStreamPort(p string) (StreamPort, error) {
split := strings.Split(p, ":") split := strings.Split(p, ":")
switch len(split) { switch len(split) {
@ -21,36 +22,14 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.Error) {
case 2: case 2:
break break
default: default:
err = E.Invalid("stream port", p).With("too many colons") return StreamPort{}, ErrStreamPortTooManyColons.Subject(p)
return
} }
listeningPort, err := ValidatePort(split[0]) listeningPort, lErr := ValidatePort(split[0])
if err != nil { proxyPort, pErr := ValidatePort(split[1])
err = err.Subject("listening port") if err := E.Join(lErr, pErr); err != nil {
return return StreamPort{}, err
}
proxyPort, err := ValidatePort(split[1])
if err.Is(E.ErrOutOfRange) {
err = err.Subject("proxy port")
return
} else if err != nil {
proxyPort, err = parseNameToPort(split[1])
if err != nil {
err = E.Invalid("proxy port", proxyPort)
return
}
} }
return StreamPort{listeningPort, proxyPort}, nil return StreamPort{listeningPort, proxyPort}, nil
} }
func parseNameToPort(name string) (Port, E.Error) {
port, ok := common.ServiceNamePortMapTCP[name]
if !ok {
return ErrPort, E.Invalid("service", name)
}
return Port(port), nil
}

View file

@ -1,9 +1,9 @@
package fields package fields
import ( import (
"strconv"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
@ -11,7 +11,6 @@ var validPorts = []string{
"1234:5678", "1234:5678",
"0:2345", "0:2345",
"2345", "2345",
"1234:postgres",
} }
var invalidPorts = []string{ var invalidPorts = []string{
@ -19,7 +18,6 @@ var invalidPorts = []string{
"123:", "123:",
"0:", "0:",
":1234", ":1234",
"1234:1234:1234",
"qwerty", "qwerty",
"asdfgh:asdfgh", "asdfgh:asdfgh",
"1234:asdfgh", "1234:asdfgh",
@ -32,17 +30,25 @@ var outOfRangePorts = []string{
"0:65536", "0:65536",
} }
var tooManyColonsPorts = []string{
"1234:1234:1234",
}
func TestStreamPort(t *testing.T) { func TestStreamPort(t *testing.T) {
for _, port := range validPorts { for _, port := range validPorts {
_, err := ValidateStreamPort(port) _, err := ValidateStreamPort(port)
ExpectNoError(t, err.Error()) ExpectNoError(t, err)
} }
for _, port := range invalidPorts { for _, port := range invalidPorts {
_, err := ValidateStreamPort(port) _, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrInvalid, err.Error()) ExpectError2(t, port, strconv.ErrSyntax, err)
} }
for _, port := range outOfRangePorts { for _, port := range outOfRangePorts {
_, err := ValidateStreamPort(port) _, err := ValidateStreamPort(port)
ExpectError2(t, port, E.ErrOutOfRange, err.Error()) ExpectError2(t, port, ErrPortOutOfRange, err)
}
for _, port := range tooManyColonsPorts {
_, err := ValidateStreamPort(port)
ExpectError2(t, port, ErrStreamPortTooManyColons, err)
} }
} }

View file

@ -12,22 +12,23 @@ type StreamScheme struct {
ProxyScheme Scheme `json:"proxy"` ProxyScheme Scheme `json:"proxy"`
} }
func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) { func ValidateStreamScheme(s string) (*StreamScheme, error) {
ss = &StreamScheme{} ss := &StreamScheme{}
parts := strings.Split(s, ":") parts := strings.Split(s, ":")
if len(parts) == 1 { if len(parts) == 1 {
parts = []string{s, s} parts = []string{s, s}
} else if len(parts) != 2 { } else if len(parts) != 2 {
return nil, E.Invalid("stream scheme", s) return nil, ErrInvalidScheme.Subject(s)
} }
ss.ListeningScheme, err = NewScheme(parts[0])
if err.HasError() { var lErr, pErr error
return nil, err ss.ListeningScheme, lErr = NewScheme(parts[0])
} ss.ProxyScheme, pErr = NewScheme(parts[1])
ss.ProxyScheme, err = NewScheme(parts[1])
if err.HasError() { if err := E.Join(lErr, pErr); err != nil {
return nil, err return nil, err
} }
return ss, nil return ss, nil
} }

View file

@ -0,0 +1,37 @@
package fields
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var (
validStreamSchemes = []string{
"tcp:tcp",
"tcp:udp",
"udp:tcp",
"udp:udp",
"tcp",
"udp",
}
invalidStreamSchemes = []string{
"tcp:tcp:",
"tcp:",
":udp:",
":udp",
"top",
}
)
func TestNewStreamScheme(t *testing.T) {
for _, s := range validStreamSchemes {
_, err := ValidateStreamScheme(s)
ExpectNoError(t, err)
}
for _, s := range invalidStreamSchemes {
_, err := ValidateStreamScheme(s)
ExpectError(t, ErrInvalidScheme, err)
}
}

View file

@ -7,13 +7,13 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/api/v1/errorpage"
"github.com/yusing/go-proxy/internal/docker/idlewatcher" "github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/loadbalancer"
"github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/net/http/middleware"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
"github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/proxy/entry"
PT "github.com/yusing/go-proxy/internal/proxy/fields" PT "github.com/yusing/go-proxy/internal/proxy/fields"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
@ -33,6 +33,8 @@ type (
rp *gphttp.ReverseProxy rp *gphttp.ReverseProxy
task task.Task task task.Task
l zerolog.Logger
} }
SubdomainKey = PT.Alias SubdomainKey = PT.Alias
@ -88,6 +90,10 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) {
ReverseProxyEntry: entry, ReverseProxyEntry: entry,
rp: rp, rp: rp,
task: task.DummyTask(), task: task.DummyTask(),
l: logger.With().
Str("type", string(entry.Scheme)).
Str("name", string(entry.Alias)).
Logger(),
} }
return r, nil return r, nil
} }
@ -107,11 +113,11 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
defer httpRoutesMu.Unlock() defer httpRoutesMu.Unlock()
if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) {
logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled")
if r.HealthCheck == nil { if r.HealthCheck == nil {
r.HealthCheck = new(health.HealthCheckConfig) r.HealthCheck = new(health.HealthCheckConfig)
} }
r.HealthCheck.Disable = true r.HealthCheck.Disable = false
} }
switch { switch {
@ -143,7 +149,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error {
if r.HealthMon != nil { if r.HealthMon != nil {
if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil {
logrus.Warn(E.FailWith("health monitor", err)) E.LogWarn("health monitor error", err, &r.l)
} }
} }
@ -209,15 +215,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) {
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if err != nil { if err != nil {
if !middleware.ServeStaticErrorPageFile(w, r) { if !middleware.ServeStaticErrorPageFile(w, r) {
logrus.Error(E.Failure("request"). logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
Subjectf("%s %s", r.Method, r.URL.String()).
With(err))
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok { if ok {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
if _, err := w.Write(errorPage); err != nil { if _, err := w.Write(errorPage); err != nil {
logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err) logger.Err(err).Msg("failed to write error page")
} }
} else { } else {
http.Error(w, err.Error(), http.StatusNotFound) http.Error(w, err.Error(), http.StatusNotFound)

Some files were not shown because too many files have changed in this diff Show more