diff --git a/Makefile b/Makefile index 6aa0a9e..cba758c 100755 --- a/Makefile +++ b/Makefile @@ -65,3 +65,6 @@ debug-list-containers: ci-test: mkdir -p /tmp/artifacts act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" + +cloc: + cloc --not-match-f '_test.go$$' cmd internal pkg \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go index 7a0be76..6657753 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,7 +2,6 @@ package main import ( "encoding/json" - "io" "log" "net/http" "os" @@ -10,15 +9,14 @@ import ( "syscall" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/notif" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/task" @@ -33,44 +31,26 @@ func main() { return } - l := logrus.WithField("module", "main") - timeFmt := "01-02 15:04:05" - fullTS := true - - if common.IsTrace { - logrus.SetLevel(logrus.TraceLevel) - timeFmt = "04:05" - fullTS = false - } else if common.IsDebug { - logrus.SetLevel(logrus.DebugLevel) - } - - if args.Command != common.CommandStart { - logrus.SetOutput(io.Discard) - } else { - logrus.SetFormatter(&logrus.TextFormatter{ - DisableSorting: true, - FullTimestamp: fullTS, - ForceColors: true, - TimestampFormat: timeFmt, - }) - logrus.Infof("go-proxy version %s", pkg.GetVersion()) - logrus.AddHook(notif.GetDispatcher()) - } - if args.Command == common.CommandReload { if err := query.ReloadServer(); err != nil { - log.Fatal(err) + E.LogFatal("server reload error", err) } - log.Print("ok") + logging.Info().Msg("ok") return } - // exit if only validate config + if args.Command == common.CommandStart { + logging.Info().Msgf("go-proxy version %s", pkg.GetVersion()) + logging.Trace().Msg("trace enabled") + // logging.AddHook(notif.GetDispatcher()) + } else { + logging.DiscardLogger() + } + if args.Command == common.CommandValidate { data, err := os.ReadFile(common.ConfigPath) if err == nil { - err = config.Validate(data).Error() + err = config.Validate(data) } if err != nil { log.Fatal("config error: ", err) @@ -88,7 +68,7 @@ func main() { var cfg *config.Config var err E.Error if cfg, err = config.Load(); err != nil { - logrus.Warn(err) + E.LogWarn("errors in config", err) } switch args.Command { @@ -145,10 +125,10 @@ func main() { autocert := config.GetAutoCertProvider() if autocert != nil { if err := autocert.Setup(); err != nil { - l.Fatal(err) + E.LogFatal("autocert setup error", err) } } else { - l.Info("autocert not configured") + logging.Info().Msg("autocert not configured") } proxyServer := server.InitProxyServer(server.Options{ @@ -174,7 +154,7 @@ func main() { <-sig // grafully shutdown - logrus.Info("shutting down") + logging.Info().Msg("shutting down") task.CancelGlobalContext() task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown)) } @@ -182,15 +162,15 @@ func main() { func prepareDirectory(dir string) { if _, err := os.Stat(dir); os.IsNotExist(err) { if err = os.MkdirAll(dir, 0o755); err != nil { - logrus.Fatalf("failed to create directory %s: %v", dir, err) + logging.Fatal().Msgf("failed to create directory %s: %v", dir, err) } } } func printJSON(obj any) { - j, err := E.Check(json.MarshalIndent(obj, "", " ")) + j, err := json.MarshalIndent(obj, "", " ") if err != nil { - logrus.Fatal(err) + logging.Fatal().Err(err).Send() } rawLogger := log.New(os.Stdout, "", 0) rawLogger.Printf("%s", j) // raw output for convenience using "jq" diff --git a/go.mod b/go.mod index 5612f4d..3370a40 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,8 @@ require ( github.com/go-acme/lego/v4 v4.19.2 github.com/gotify/server/v2 v2.5.0 github.com/puzpuzpuz/xsync/v3 v3.4.0 + github.com/rs/zerolog v1.33.0 github.com/santhosh-tekuri/jsonschema v1.2.4 - github.com/sirupsen/logrus v1.9.3 golang.org/x/net v0.30.0 golang.org/x/text v0.19.0 gopkg.in/yaml.v3 v3.0.1 @@ -20,7 +20,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cloudflare/cloudflare-go v0.107.0 // indirect + github.com/cloudflare/cloudflare-go v0.108.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -33,6 +33,8 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.62 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect @@ -42,6 +44,7 @@ require ( github.com/ovh/go-ovh v1.6.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect diff --git a/go.sum b/go.sum index f0d259b..62f4b45 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc= -github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc= +github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0= +github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -40,6 +41,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -61,6 +63,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= @@ -88,6 +96,9 @@ github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPK github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -140,6 +151,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/api/handler.go b/internal/api/handler.go index 8f66864..5ee66a0 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -6,7 +6,6 @@ import ( "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -38,7 +37,6 @@ func NewHandler() http.Handler { mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("GET", "/v1/stats", v1.Stats) mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) - mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc()) return mux } @@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { host, _, _ := net.SplitHostPort(r.RemoteAddr) if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { - Logger.Warnf("blocked API request from %s", host) + LogWarn(r).Msgf("blocked API request from %s", host) http.Error(w, "forbidden", http.StatusForbidden) return } diff --git a/internal/api/v1/errorpage/http_handler.go b/internal/api/v1/errorpage/http_handler.go deleted file mode 100644 index 2da9372..0000000 --- a/internal/api/v1/errorpage/http_handler.go +++ /dev/null @@ -1,31 +0,0 @@ -package errorpage - -import ( - "net/http" - - . "github.com/yusing/go-proxy/internal/api/v1/utils" -) - -func GetHandleFunc() http.HandlerFunc { - setup() - return serveHTTP -} - -func serveHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if r.URL.Path == "/" { - http.Error(w, "invalid path", http.StatusNotFound) - return - } - content, ok := fileContentMap.Load(r.URL.Path) - if !ok { - http.Error(w, "404 not found", http.StatusNotFound) - return - } - if _, err := w.Write(content); err != nil { - HandleErr(w, r, err) - } -} diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index 084b4f1..acaeabb 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { return } - var validateErr E.Error + var valErr E.Error if filename == common.ConfigFileName { - validateErr = config.Validate(content) + valErr = config.Validate(content) } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { - validateErr = provider.Validate(content) + valErr = provider.Validate(content) } + // no validation for include files - if validateErr != nil { - U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest) + if valErr != nil { + U.RespondJSON(w, r, valErr, http.StatusBadRequest) return } diff --git a/internal/api/v1/query/query.go b/internal/api/v1/query/query.go index cd30cc7..7757004 100644 --- a/internal/api/v1/query/query.go +++ b/internal/api/v1/query/query.go @@ -20,16 +20,13 @@ func ReloadServer() E.Error { } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode) - b, err := io.ReadAll(resp.Body) + failure := E.Errorf("server reload status %v", resp.StatusCode) + body, err := io.ReadAll(resp.Body) if err != nil { - return failure.Extraf("unable to read response body: %s", err) + return failure.With(err) } - reloadErr, ok := E.FromJSON(b) - if ok { - return E.Join("reload success, but server returned error", reloadErr) - } - return failure.Extraf("unable to read response body") + reloadErr := string(body) + return failure.Withf(reloadErr) } return nil } @@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) { } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode) + outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode) return } var res T diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index ffd0609..42a4198 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -9,8 +9,8 @@ import ( func Reload(w http.ResponseWriter, r *http.Request) { if err := config.Reload(); err != nil { - U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) - } else { - w.WriteHeader(http.StatusOK) + U.HandleErr(w, r, err) + return } + U.WriteBody(w, []byte("OK")) } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index c86d325..e40c364 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -11,7 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/server" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) func Stats(w http.ResponseWriter, r *http.Request) { @@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses)) if len(originPats) == 0 { - U.Logger.Warnf("no match domains configured, accepting websocket request from all origins") + U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") originPats = []string{"*"} } else { for i, domain := range config.Value().MatchDomains { @@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { OriginPatterns: originPats, }) if err != nil { - U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err) + U.LogError(r).Err(err).Msg("failed to upgrade websocket") return } /* trunk-ignore(golangci-lint/errcheck) */ @@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { for range ticker.C { stats := getStats() if err := wsjson.Write(ctx, conn, stats); err != nil { - U.Logger.Errorf("/stats/ws failed to write JSON: %s", err) + U.LogError(r).Msg("failed to write JSON") return } } @@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { func getStats() map[string]any { return map[string]any{ "proxies": config.Statistics(), - "uptime": utils.FormatDuration(server.GetProxyServer().Uptime()), + "uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()), } } diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 957e6cd..3c55d95 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -1,37 +1,31 @@ package utils import ( - "errors" - "fmt" "net/http" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" ) -var Logger = logrus.WithField("module", "api") - func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { if origErr == nil { return } - err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL) - Logger.Error(err) + LogError(r).Msg(origErr.Error()) if len(code) > 0 { - http.Error(w, err.String(), code[0]) + http.Error(w, origErr.Error(), code[0]) return } - http.Error(w, err.String(), http.StatusInternalServerError) + http.Error(w, origErr.Error(), http.StatusInternalServerError) } func ErrMissingKey(k string) error { - return errors.New("missing key '" + k + "' in query or request body") + return E.New("missing key '" + k + "' in query or request body") } func ErrInvalidKey(k string) error { - return errors.New("invalid key '" + k + "' in query or request body") + return E.New("invalid key '" + k + "' in query or request body") } func ErrNotFound(k, v string) error { - return fmt.Errorf("key %q with value %q not found", k, v) + return E.Errorf("key %q with value %q not found", k, v) } diff --git a/internal/api/v1/utils/http_client.go b/internal/api/v1/utils/http_client.go index 0cb4ebe..48d743b 100644 --- a/internal/api/v1/utils/http_client.go +++ b/internal/api/v1/utils/http_client.go @@ -9,12 +9,11 @@ import ( ) var ( - HTTPClient = &http.Client{ + httpClient = &http.Client{ Timeout: common.ConnectionTimeout, Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, DisableKeepAlives: true, - ForceAttemptHTTP2: true, + ForceAttemptHTTP2: false, DialContext: (&net.Dialer{ Timeout: common.DialTimeout, KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives @@ -23,7 +22,7 @@ var ( }, } - Get = HTTPClient.Get - Post = HTTPClient.Post - Head = HTTPClient.Head + Get = httpClient.Get + Post = httpClient.Post + Head = httpClient.Head ) diff --git a/internal/api/v1/utils/logging.go b/internal/api/v1/utils/logging.go new file mode 100644 index 0000000..cdf4222 --- /dev/null +++ b/internal/api/v1/utils/logging.go @@ -0,0 +1,18 @@ +package utils + +import ( + "net/http" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/logging" +) + +func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { + return logging.WithLevel(level).Str("module", "api"). + Str("method", r.Method). + Str("path", r.RequestURI) +} + +func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) } +func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) } +func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) } diff --git a/internal/api/v1/utils/utils.go b/internal/api/v1/utils/utils.go index c687b44..01015a2 100644 --- a/internal/api/v1/utils/utils.go +++ b/internal/api/v1/utils/utils.go @@ -3,6 +3,8 @@ package utils import ( "encoding/json" "net/http" + + "github.com/yusing/go-proxy/internal/logging" ) func WriteBody(w http.ResponseWriter, body []byte) { @@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) { } } -func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool { +func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) { if len(code) > 0 { w.WriteHeader(code[0]) } w.Header().Set("Content-Type", "application/json") - j, err := json.MarshalIndent(data, "", " ") - if err != nil { - HandleErr(w, r, err) - return false + var j []byte + var err error + + switch data := data.(type) { + case string: + j = []byte(`"` + data + `"`) + case []byte: + j = data + default: + j, err = json.MarshalIndent(data, "", " ") + if err != nil { + logging.Panic().Err(err).Msg("failed to marshal json") + return false + } } _, err = w.Write(j) if err != nil { diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 351e479..670db93 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -8,12 +8,21 @@ import ( "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/lego" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/config/types" ) type Config types.AutoCertConfig +var ( + ErrMissingDomain = E.New("missing field 'domains'") + ErrMissingEmail = E.New("missing field 'email'") + ErrMissingProvider = E.New("missing field 'provider'") + ErrUnknownProvider = E.New("unknown provider") +) + func NewConfig(cfg *types.AutoCertConfig) *Config { if cfg.CertPath == "" { cfg.CertPath = CertFileDefault @@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config { return (*Config)(cfg) } -func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { - b := E.NewBuilder("unable to initialize autocert") - defer b.To(&res) +func (cfg *Config) GetProvider() (*Provider, E.Error) { + b := E.NewBuilder("autocert errors") if cfg.Provider != ProviderLocal { if len(cfg.Domains) == 0 { - b.Addf("%s", "no domains specified") + b.Add(ErrMissingDomain) } if cfg.Provider == "" { - b.Addf("%s", "no provider specified") + b.Add(ErrMissingProvider) } if cfg.Email == "" { - b.Addf("%s", "no email specified") + b.Add(ErrMissingEmail) } // check if provider is implemented _, ok := providersGenMap[cfg.Provider] if !ok { - b.Addf("unknown provider: %q", cfg.Provider) + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap)))) } } if b.HasError() { - return + return nil, b.Error() } - privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) - if err.HasError() { - b.Add(E.FailWith("generate private key", err)) - return + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + b.Addf("generate private key: %w", err) + return nil, b.Error() } user := &User{ @@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { legoCfg := lego.NewConfig(user) legoCfg.Certificate.KeyType = certcrypto.RSA2048 - provider = &Provider{ + return &Provider{ cfg: cfg, user: user, legoCfg: legoCfg, - } - - return + }, nil } diff --git a/internal/autocert/constants.go b/internal/autocert/constants.go index 19b726f..ffe0cc3 100644 --- a/internal/autocert/constants.go +++ b/internal/autocert/constants.go @@ -7,7 +7,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/ovh" - "github.com/sirupsen/logrus" ) const ( @@ -36,5 +35,3 @@ var providersGenMap = map[string]ProviderGenerator{ var ( ErrGetCertFailure = errors.New("get certificate failed") ) - -var logger = logrus.WithField("module", "autocert") diff --git a/internal/autocert/logger.go b/internal/autocert/logger.go new file mode 100644 index 0000000..9339fa7 --- /dev/null +++ b/internal/autocert/logger.go @@ -0,0 +1,5 @@ +package autocert + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "autocert").Logger() diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 8ab2aad..4f13673 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -17,6 +17,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -57,25 +58,20 @@ func (p *Provider) GetExpiries() CertExpiries { return p.certExpiries } -func (p *Provider) ObtainCert() (res E.Error) { - b := E.NewBuilder("failed to obtain certificate") - defer b.To(&res) - +func (p *Provider) ObtainCert() E.Error { if p.cfg.Provider == ProviderLocal { return nil } if p.client == nil { - if err := p.initClient(); err.HasError() { - b.Add(E.FailWith("init autocert client", err)) - return + if err := p.initClient(); err != nil { + return err } } if p.user.Registration == nil { - if err := p.registerACME(); err.HasError() { - b.Add(E.FailWith("register ACME", err)) - return + if err := p.registerACME(); err != nil { + return E.From(err) } } @@ -84,27 +80,23 @@ func (p *Provider) ObtainCert() (res E.Error) { Domains: p.cfg.Domains, Bundle: true, } - cert, err := E.Check(client.Certificate.Obtain(req)) - if err.HasError() { - b.Add(err) - return + cert, err := client.Certificate.Obtain(req) + if err != nil { + return E.From(err) } - if err = p.saveCert(cert); err.HasError() { - b.Add(E.FailWith("save certificate", err)) - return + if err = p.saveCert(cert); err != nil { + return E.From(err) } - tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) - if err.HasError() { - b.Add(E.FailWith("parse obtained certificate", err)) - return + tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey) + if err != nil { + return E.From(err) } expiries, err := getCertExpiries(&tlsCert) - if err.HasError() { - b.Add(E.FailWith("get certificate expiry", err)) - return + if err != nil { + return E.From(err) } p.tlsCert = &tlsCert p.certExpiries = expiries @@ -113,21 +105,22 @@ func (p *Provider) ObtainCert() (res E.Error) { } func (p *Provider) LoadCert() E.Error { - cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) - if err.HasError() { - return err + cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) + if err != nil { + return E.Errorf("load SSL certificate: %w", err) } expiries, err := getCertExpiries(&cert) - if err.HasError() { - return err + if err != nil { + return E.Errorf("parse SSL certificate: %w", err) } p.tlsCert = &cert p.certExpiries = expiries - logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn()))) + logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn()))) return p.renewIfNeeded() } +// ShouldRenewOn returns the time at which the certificate should be renewed. func (p *Provider) ShouldRenewOn() time.Time { for _, expiry := range p.certExpiries { return expiry.AddDate(0, -1, 0) // 1 month before @@ -150,8 +143,8 @@ func (p *Provider) ScheduleRenewal() { case <-task.Context().Done(): return case <-ticker.C: // check every 5 seconds - if err := p.renewIfNeeded(); err.HasError() { - logger.Warn(err) + if err := p.renewIfNeeded(); err != nil { + E.LogWarn("cert renew failed", err, &logger) } } } @@ -159,31 +152,32 @@ func (p *Provider) ScheduleRenewal() { } func (p *Provider) initClient() E.Error { - legoClient, err := E.Check(lego.NewClient(p.legoCfg)) - if err.HasError() { - return E.FailWith("create lego client", err) + legoClient, err := lego.NewClient(p.legoCfg) + if err != nil { + return E.From(err) } - legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) - if err.HasError() { - return E.FailWith("create lego provider", err) + generator := providersGenMap[p.cfg.Provider] + legoProvider, pErr := generator(p.cfg.Options) + if pErr != nil { + return pErr } - err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) - if err.HasError() { - return E.FailWith("set challenge provider", err) + err = legoClient.Challenge.SetDNS01Provider(legoProvider) + if err != nil { + return E.From(err) } p.client = legoClient return nil } -func (p *Provider) registerACME() E.Error { +func (p *Provider) registerACME() error { if p.user.Registration != nil { return nil } - reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) - if err.HasError() { + reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) + if err != nil { return err } p.user.Registration = reg @@ -191,26 +185,27 @@ func (p *Provider) registerACME() E.Error { return nil } -func (p *Provider) saveCert(cert *certificate.Resource) E.Error { +func (p *Provider) saveCert(cert *certificate.Resource) error { /* This should have been done in setup but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) if err != nil { if os.IsNotExist(err) { if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { - return E.FailWith("create cert directory", err) + return err } } else { - return E.FailWith("stat cert directory", err) + return err } } err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- if err != nil { - return E.FailWith("write key file", err) + return err } + err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- if err != nil { - return E.FailWith("write cert file", err) + return err } return nil } @@ -232,7 +227,7 @@ func (p *Provider) certState() CertState { sort.Strings(certDomains) if !reflect.DeepEqual(certDomains, wantedDomains) { - logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) + logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) return CertStateMismatch } @@ -246,25 +241,25 @@ func (p *Provider) renewIfNeeded() E.Error { switch p.certState() { case CertStateExpired: - logger.Info("certs expired, renewing") + logger.Info().Msg("certs expired, renewing") case CertStateMismatch: - logger.Info("cert domains mismatch with config, renewing") + logger.Info().Msg("cert domains mismatch with config, renewing") default: return nil } - if err := p.ObtainCert(); err.HasError() { - return E.FailWith("renew certificate", err) + if err := p.ObtainCert(); err != nil { + return err } return nil } -func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) { +func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { r := make(CertExpiries, len(cert.Certificate)) for _, cert := range cert.Certificate { - x509Cert, err := E.Check(x509.ParseCertificate(cert)) - if err.HasError() { - return nil, E.FailWith("parse certificate", err) + x509Cert, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err } if x509Cert.IsCA { continue @@ -284,13 +279,10 @@ func providerGenerator[CT any, PT challenge.Provider]( return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { cfg := defaultCfg() err := U.Deserialize(opt, cfg) - if err.HasError() { + if err != nil { return nil, err } - p, err := E.Check(newProvider(cfg)) - if err.HasError() { - return nil, err - } - return p, nil + p, pErr := newProvider(cfg) + return p, E.From(pErr) } } diff --git a/internal/autocert/provider_test/ovh_test.go b/internal/autocert/provider_test/ovh_test.go index e9d778f..35bf059 100644 --- a/internal/autocert/provider_test/ovh_test.go +++ b/internal/autocert/provider_test/ovh_test.go @@ -45,6 +45,6 @@ oauth2_config: testYaml = testYaml[1:] // remove first \n opt := make(map[string]any) ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt)) - ExpectNoError(t, U.Deserialize(opt, cfg).Error()) + ExpectNoError(t, U.Deserialize(opt, cfg)) ExpectDeepEqual(t, cfg, cfgExpected) } diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 62640b1..c65aa35 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -11,7 +11,7 @@ func (p *Provider) Setup() (err E.Error) { if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist return err } - logger.Debug("obtaining cert due to error loading cert") + logger.Debug().Msg("obtaining cert due to error loading cert") if err = p.ObtainCert(); err != nil { return err } @@ -20,7 +20,7 @@ func (p *Provider) Setup() (err E.Error) { p.ScheduleRenewal() for _, expiry := range p.GetExpiries() { - logger.Infof("certificate expire on %s", expiry) + logger.Info().Msg("certificate expire on " + expiry.String()) break } diff --git a/internal/common/args.go b/internal/common/args.go index b0dd597..35563e7 100644 --- a/internal/common/args.go +++ b/internal/common/args.go @@ -3,8 +3,7 @@ package common import ( "flag" "fmt" - - "github.com/sirupsen/logrus" + "log" ) type Args struct { @@ -44,7 +43,7 @@ func GetArgs() Args { flag.Parse() args.Command = flag.Arg(0) if err := validateArg(args.Command); err != nil { - logrus.Fatal(err) + log.Fatalf("invalid command: %s", err) } return args } @@ -55,5 +54,5 @@ func validateArg(arg string) error { return nil } } - return fmt.Errorf("invalid command: %s", arg) + return fmt.Errorf("invalid command %q", arg) } diff --git a/internal/common/env.go b/internal/common/env.go index 13d71fe..a8ad9b7 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -7,8 +7,6 @@ import ( "os" "strconv" "strings" - - "github.com/sirupsen/logrus" ) var ( @@ -40,7 +38,7 @@ func GetEnvBool(key string, defaultValue bool) bool { } b, err := strconv.ParseBool(value) if err != nil { - log.Fatalf("Invalid boolean value: %s", value) + log.Fatalf("env %s: invalid boolean value: %s", key, value) } return b } @@ -57,7 +55,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str addr = GetEnv(key, defaultValue) host, port, err := net.SplitHostPort(addr) if err != nil { - logrus.Fatalf("Invalid address: %s", addr) + log.Fatalf("env %s: invalid address: %s", key, addr) } if host == "" { host = "localhost" diff --git a/internal/config/config.go b/internal/config/config.go index 51a42cf..61e1f7e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,14 +2,15 @@ package config import ( "os" + "strconv" "sync" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/route" proxy "github.com/yusing/go-proxy/internal/route/provider" @@ -31,7 +32,7 @@ type Config struct { var ( instance *Config cfgWatcher watcher.Watcher - logger = logrus.WithField("module", "config") + logger = logging.With().Str("module", "config").Logger() reloadMu sync.Mutex ) @@ -80,7 +81,7 @@ func WatchChanges() { configEventFlushInterval, OnConfigChange, func(err E.Error) { - logger.Error(err) + E.LogError("config reload error", err, &logger) }, ) eventQueue.Start(cfgWatcher.Events(task.Context())) @@ -93,15 +94,16 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) { // just reload once and check the last event switch ev[len(ev)-1].Action { case events.ActionFileRenamed: - logger.Warn(cfgRenameWarn) + logger.Warn().Msg(cfgRenameWarn) return case events.ActionFileDeleted: - logger.Warn(cfgDeleteWarn) + logger.Warn().Msg(cfgDeleteWarn) return } if err := Reload(); err != nil { - logger.Error(err) + // recovered in event queue + panic(err) } } @@ -138,63 +140,57 @@ func (cfg *Config) Task() task.Task { } func (cfg *Config) StartProxyProviders() { - b := E.NewBuilder("errors starting providers") - cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { - b.Add(p.Start(cfg.task.Subtask(p.String()))) - }) + errs := cfg.providers.CollectErrorsParallel( + func(_ string, p *proxy.Provider) error { + subtask := cfg.task.Subtask(p.String()) + return p.Start(subtask) + }) - if b.HasError() { - logger.Error(b.Build()) + if err := E.Join(errs...); err != nil { + E.LogError("route provider errors", err, &logger) } } -func (cfg *Config) load() (res E.Error) { - errs := E.NewBuilder("errors loading config") - defer errs.To(&res) +func (cfg *Config) load() E.Error { + const errMsg = "config load error" - logger.Debug("loading config") - defer logger.Debug("loaded config") - - data, err := E.Check(os.ReadFile(common.ConfigPath)) + data, err := os.ReadFile(common.ConfigPath) if err != nil { - errs.Add(E.FailWith("read config", err)) - logrus.Fatal(errs.Build()) + E.LogFatal(errMsg, err, &logger) } if !common.NoSchemaValidation { - if err = Validate(data); err != nil { - errs.Add(E.FailWith("schema validation", err)) - logrus.Fatal(errs.Build()) + if err := Validate(data); err != nil { + E.LogFatal(errMsg, err, &logger) } } model := types.DefaultConfig() if err := E.From(yaml.Unmarshal(data, model)); err != nil { - errs.Add(E.FailWith("parse config", err)) - logrus.Fatal(errs.Build()) + E.LogFatal(errMsg, err, &logger) } // errors are non fatal below + errs := E.NewBuilder(errMsg) errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initAutoCert(&model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) cfg.value = model route.SetFindMuxDomains(model.MatchDomains) - return + return errs.Error() } func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) { if len(notifCfgMap) == 0 { return } - errs := E.NewBuilder("errors initializing notification providers") - + errs := E.NewBuilder("notification providers load errors") for name, notifCfg := range notifCfgMap { _, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg) errs.Add(err) } - return errs.Build() + return errs.Error() } func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { @@ -203,40 +199,45 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) } cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() - if err != nil { - err = E.FailWith("autocert provider", err) - } return } -func (cfg *Config) loadRouteProviders(providers *types.Providers) (outErr E.Error) { +func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { subtask := cfg.task.Subtask("load route providers") defer subtask.Finish("done") - errs := E.NewBuilder("errors loading route providers") - results := E.NewBuilder("loaded providers") - defer errs.To(&outErr) + errs := E.NewBuilder("route provider errors") + results := E.NewBuilder("loaded route providers") + lenLongestName := 0 for _, filename := range providers.Files { p, err := proxy.NewFileProvider(filename) if err != nil { - errs.Add(err) + errs.Add(E.PrependSubject(filename, err)) continue } cfg.providers.Store(p.GetName(), p) - errs.Add(p.LoadRoutes().Subject(filename)) - results.Addf("%d routes from %s", p.NumRoutes(), p.String()) + if len(p.GetName()) > lenLongestName { + lenLongestName = len(p.GetName()) + } } for name, dockerHost := range providers.Docker { p, err := proxy.NewDockerProvider(name, dockerHost) if err != nil { - errs.Add(err.Subjectf("%s (%s)", name, dockerHost)) + errs.Add(E.PrependSubject(name, err)) continue } cfg.providers.Store(p.GetName(), p) - errs.Add(p.LoadRoutes().Subject(p.GetName())) - results.Addf("%d routes from %s", p.NumRoutes(), p.String()) + if len(p.GetName()) > lenLongestName { + lenLongestName = len(p.GetName()) + } } - logger.Info(results.Build()) - return + cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { + if err := p.LoadRoutes(); err != nil { + errs.Add(err.Subject(p.String())) + } + results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes()) + }) + logger.Info().Msg(results.String()) + return errs.Error() } diff --git a/internal/config/query.go b/internal/config/query.go index 813b8c9..8933cf3 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -9,8 +9,8 @@ import ( "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/route" proxy "github.com/yusing/go-proxy/internal/route/provider" - U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" ) func DumpEntries() map[string]*entry.RawEntry { @@ -61,7 +61,7 @@ func HomepageConfig() homepage.Config { } if item.Name == "" { - item.Name = U.Title( + item.Name = strutils.Title( strings.ReplaceAll( strings.ReplaceAll(alias, "-", " "), "_", " ", diff --git a/internal/docker/client.go b/internal/docker/client.go index d4cb5f4..f1ecf5b 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -1,14 +1,15 @@ package docker import ( + "errors" "net/http" "sync" "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -22,7 +23,7 @@ type ( key string refCount *U.RefCount - l logrus.FieldLogger + l zerolog.Logger } ) @@ -70,7 +71,7 @@ func (c *SharedClient) Close() error { // Returns: // - Client: the Docker client connection. // - error: an error if the connection failed. -func ConnectClient(host string) (Client, E.Error) { +func ConnectClient(host string) (Client, error) { clientMapMu.Lock() defer clientMapMu.Unlock() @@ -85,13 +86,13 @@ func ConnectClient(host string) (Client, E.Error) { switch host { case "": - return nil, E.Invalid("docker host", "empty") + return nil, errors.New("empty docker host") case common.DockerHostFromEnv: opt = clientOptEnvHost default: - helper, err := E.Check(connhelper.GetConnectionHelper(host)) - if err.HasError() { - return nil, E.UnexpectedError(err.Error()) + helper, err := connhelper.GetConnectionHelper(host) + if err != nil { + logging.Panic().Err(err).Msg("failed to get connection helper") } if helper != nil { httpClient := &http.Client{ @@ -113,9 +114,9 @@ func ConnectClient(host string) (Client, E.Error) { } } - client, err := E.Check(client.NewClientWithOpts(opt...)) + client, err := client.NewClientWithOpts(opt...) - if err.HasError() { + if err != nil { return nil, err } @@ -123,9 +124,9 @@ func ConnectClient(host string) (Client, E.Error) { Client: client, key: host, refCount: U.NewRefCounter(), - l: logger.WithField("docker_client", client.DaemonHost()), + l: logger.With().Str("address", client.DaemonHost()).Logger(), } - c.l.Debugf("client connected") + c.l.Trace().Msg("client connected") clientMap.Store(host, c) @@ -135,7 +136,7 @@ func ConnectClient(host string) (Client, E.Error) { if c.Connected() { c.Client.Close() - c.l.Debugf("client closed") + c.l.Trace().Msg("client closed") } }() return c, nil diff --git a/internal/docker/container.go b/internal/docker/container.go index d0afa9f..f87535f 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/docker/docker/api/types" - "github.com/sirupsen/logrus" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -59,7 +59,7 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { NetworkMode: c.HostConfig.NetworkMode, Aliases: helper.getAliases(), - IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)), + IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), IsExplicit: isExplicit, IsDatabase: helper.isDatabase(), IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), @@ -120,7 +120,7 @@ func (c *Container) setPublicIP() { } url, err := url.Parse(c.DockerHost) if err != nil { - logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err) + logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost) c.PublicIP = "127.0.0.1" return } diff --git a/internal/docker/container_helper.go b/internal/docker/container_helper.go index 12d6cc7..d3f3a63 100644 --- a/internal/docker/container_helper.go +++ b/internal/docker/container_helper.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/docker/docker/api/types" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type containerHelper struct { @@ -23,7 +23,7 @@ func (c containerHelper) getDeleteLabel(label string) string { func (c containerHelper) getAliases() []string { if l := c.getDeleteLabel(LabelAliases); l != "" { - return U.CommaSeperatedList(l) + return strutils.CommaSeperatedList(l) } return []string{c.getName()} } @@ -44,7 +44,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping { if v.PublicPort == 0 { continue } - res[U.PortString(v.PublicPort)] = v + res[strutils.PortString(v.PublicPort)] = v } return res } @@ -52,7 +52,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping { func (c containerHelper) getPrivatePortMapping() PortMapping { res := make(PortMapping) for _, v := range c.Ports { - res[U.PortString(v.PrivatePort)] = v + res[strutils.PortString(v.PrivatePort)] = v } return res } diff --git a/internal/docker/idlewatcher/types/config.go b/internal/docker/idlewatcher/types/config.go index a12829f..95201c7 100644 --- a/internal/docker/idlewatcher/types/config.go +++ b/internal/docker/idlewatcher/types/config.go @@ -1,6 +1,7 @@ package types import ( + "errors" "time" "github.com/yusing/go-proxy/internal/docker" @@ -30,7 +31,7 @@ const ( StopMethodKill StopMethod = "kill" ) -func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { +func ValidateConfig(cont *docker.Container) (*Config, E.Error) { if cont == nil { return nil, nil } @@ -44,26 +45,16 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { }, nil } - b := E.NewBuilder("invalid idlewatcher config") - defer b.To(&res) + errs := E.NewBuilder("invalid idlewatcher config") - idleTimeout, err := validateDurationPostitive(cont.IdleTimeout) - b.Add(err.Subjectf("%s", "idle_timeout")) + idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout) + wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout) + stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout) + stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod) + signal := E.Collect(errs, validateSignal, cont.StopSignal) - wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout) - b.Add(err.Subjectf("%s", "wake_timeout")) - - stopTimeout, err := validateDurationPostitive(cont.StopTimeout) - b.Add(err.Subjectf("%s", "stop_timeout")) - - stopMethod, err := validateStopMethod(cont.StopMethod) - b.Add(err) - - signal, err := validateSignal(cont.StopSignal) - b.Add(err) - - if err := b.Build(); err != nil { - return + if errs.HasError() { + return nil, errs.Error() } return &Config{ @@ -80,33 +71,33 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { }, nil } -func validateDurationPostitive(value string) (time.Duration, E.Error) { +func validateDurationPostitive(value string) (time.Duration, error) { d, err := time.ParseDuration(value) if err != nil { - return 0, E.Invalid("duration", value).With(err) + return 0, err } if d < 0 { - return 0, E.Invalid("duration", "negative value") + return 0, errors.New("duration must be positive") } return d, nil } -func validateSignal(s string) (Signal, E.Error) { +func validateSignal(s string) (Signal, error) { switch s { case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", "INT", "TERM", "HUP", "QUIT": return Signal(s), nil } - return "", E.Invalid("signal", s) + return "", errors.New("invalid signal " + s) } -func validateStopMethod(s string) (StopMethod, E.Error) { +func validateStopMethod(s string) (StopMethod, error) { sm := StopMethod(s) switch sm { case StopMethodPause, StopMethodStop, StopMethodKill: return sm, nil default: - return "", E.Invalid("stop_method", sm) + return "", errors.New("invalid stop method " + s) } } diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 23630a6..6fda047 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -42,7 +42,7 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr watcher, err := registerWatcher(providerSubTask, entry, waker) if err != nil { - return nil, err + return nil, E.Errorf("register watcher: %w", err) } if rp != nil { @@ -75,6 +75,9 @@ func (w *Watcher) Start(routeSubTask task.Task) E.Error { // Finish implements health.HealthMonitor. func (w *Watcher) Finish(reason any) { + if w.stream != nil { + w.stream.Close() + } } // Name implements health.HealthMonitor. diff --git a/internal/docker/idlewatcher/waker_http.go b/internal/docker/idlewatcher/waker_http.go index b7cf6c2..2634e5f 100644 --- a/internal/docker/idlewatcher/waker_http.go +++ b/internal/docker/idlewatcher/waker_http.go @@ -7,9 +7,7 @@ import ( "strconv" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -20,21 +18,26 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if !shouldNext { return } - w.rp.ServeHTTP(rw, r) + select { + case <-r.Context().Done(): + return + default: + w.rp.ServeHTTP(rw, r) + } } func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { w.resetIdleTimer() - if r.Body != nil { - defer r.Body.Close() - } - // pass through if container is already ready if w.ready.Load() { return true } + if r.Body != nil { + defer r.Body.Close() + } + accept := gphttp.GetAccept(r.Header) acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) @@ -49,23 +52,22 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN rw.Header().Add("Cache-Control", "must-revalidate") rw.Header().Add("Connection", "close") if _, err := rw.Write(body); err != nil { - w.l.Errorf("error writing http response: %s", err) + w.Err(err).Msg("error writing http response") } - return + return false } ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) defer cancel() - checkCanceled := func() bool { + checkCanceled := func() (canceled bool) { select { - case <-w.task.Context().Done(): - w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context())) - http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) - return true case <-ctx.Done(): - w.l.Debugf("wake canceled: %s", context.Cause(ctx)) - http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) + w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled") + return true + case <-w.task.Context().Done(): + w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled") + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) return true default: return false @@ -76,12 +78,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN return false } - w.l.Debug("wake signal received") + w.WakeTrace().Msg("signal received") err := w.wakeIfStopped() if err != nil { - w.l.Error(E.FailWith("wake", err)) + w.WakeError(err).Send() http.Error(rw, "Error waking container", http.StatusInternalServerError) - return + return false } for { @@ -92,11 +94,11 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN if w.Status() == health.StatusHealthy { w.resetIdleTimer() if isCheckRedirect { - logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL()) + w.Debug().Msgf("redirecting to %s ...", w.hc.URL()) rw.WriteHeader(http.StatusOK) - return + return false } - logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + w.Debug().Msgf("passing through to %s ...", w.hc.URL()) return true } diff --git a/internal/docker/idlewatcher/waker_stream.go b/internal/docker/idlewatcher/waker_stream.go index 1ec4174..5cd648b 100644 --- a/internal/docker/idlewatcher/waker_stream.go +++ b/internal/docker/idlewatcher/waker_stream.go @@ -7,7 +7,7 @@ import ( "net" "time" - "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -16,25 +16,25 @@ func (w *Watcher) Addr() net.Addr { return w.stream.Addr() } +// Setup implements types.Stream. func (w *Watcher) Setup() error { return w.stream.Setup() } // Accept implements types.Stream. -func (w *Watcher) Accept() (conn net.Conn, err error) { +func (w *Watcher) Accept() (conn types.StreamConn, err error) { conn, err = w.stream.Accept() if err != nil { - logrus.Errorf("accept failed with error: %s", err) return } - if err := w.wakeFromStream(); err != nil { - w.l.Error(err) + if wakeErr := w.wakeFromStream(); wakeErr != nil { + w.WakeError(wakeErr).Msg("error waking from stream") } return } // Handle implements types.Stream. -func (w *Watcher) Handle(conn net.Conn) error { +func (w *Watcher) Handle(conn types.StreamConn) error { if err := w.wakeFromStream(); err != nil { return err } @@ -54,11 +54,11 @@ func (w *Watcher) wakeFromStream() error { return nil } - w.l.Debug("wake signal received") + w.WakeDebug().Msg("wake signal received") wakeErr := w.wakeIfStopped() if wakeErr != nil { - wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr) - w.l.Error(wakeErr) + wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr) + w.WakeError(wakeErr).Msg("wake failed") return wakeErr } @@ -69,18 +69,18 @@ func (w *Watcher) wakeFromStream() error { select { case <-w.task.Context().Done(): cause := w.task.FinishCause() - w.l.Debugf("wake canceled: %s", cause) + w.WakeDebug().Str("cause", cause.Error()).Msg("canceled") return cause case <-ctx.Done(): cause := context.Cause(ctx) - w.l.Debugf("wake canceled: %s", cause) + w.WakeDebug().Str("cause", cause.Error()).Msg("timeout") return cause default: } if w.Status() == health.StatusHealthy { w.resetIdleTimer() - logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String()) return nil } diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index e5388fc..d23502d 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -3,15 +3,15 @@ package idlewatcher import ( "context" "errors" - "fmt" "sync" "time" "github.com/docker/docker/api/types/container" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" D "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" @@ -25,6 +25,8 @@ type ( Watcher struct { _ U.NoCopy + zerolog.Logger + *idlewatcher.Config *waker @@ -32,7 +34,6 @@ type ( stopByMethod StopCallback // send a docker command w.r.t. `stop_method` ticker *time.Ticker task task.Task - l *logrus.Entry } WakeDone <-chan error @@ -44,13 +45,12 @@ var ( watcherMap = F.NewMapOf[string, *Watcher]() watcherMapMu sync.Mutex - logger = logrus.WithField("module", "idle_watcher") + logger = logging.With().Str("module", "idle_watcher").Logger() ) const dockerReqTimeout = 3 * time.Second -func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) { - failure := E.Failure("idle_watcher register") +func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) { cfg := entry.IdlewatcherConfig() if cfg.IdleTimeout == 0 { @@ -71,17 +71,17 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) } client, err := D.ConnectClient(cfg.DockerHost) - if err.HasError() { - return nil, failure.With(err) + if err != nil { + return nil, err } w := &Watcher{ + Logger: logger.With().Str("name", cfg.ContainerName).Logger(), Config: cfg, waker: waker, client: client, task: providerSubtask, ticker: time.NewTicker(cfg.IdleTimeout), - l: logger.WithField("container", cfg.ContainerName), } w.stopByMethod = w.getStopCallback() watcherMap.Store(key, w) @@ -99,6 +99,23 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) return w, nil } +// WakeDebug logs a debug message related to waking the container. +func (w *Watcher) WakeDebug() *zerolog.Event { + return w.Debug().Str("action", "wake") +} + +func (w *Watcher) WakeTrace() *zerolog.Event { + return w.Trace().Str("action", "wake") +} + +func (w *Watcher) WakeError(err error) *zerolog.Event { + return w.Err(err).Str("action", "wake") +} + +func (w *Watcher) LogReason(action, reason string) { + w.Info().Str("reason", reason).Msg(action) +} + func (w *Watcher) containerStop(ctx context.Context) error { return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{ Signal: string(w.StopSignal), @@ -130,7 +147,7 @@ func (w *Watcher) containerStatus() (string, error) { defer cancel() json, err := w.client.ContainerInspect(ctx, w.ContainerID) if err != nil { - return "", fmt.Errorf("failed to inspect container: %w", err) + return "", err } return json.State.Status, nil } @@ -181,7 +198,7 @@ func (w *Watcher) getStopCallback() StopCallback { } func (w *Watcher) resetIdleTimer() { - w.l.Trace("reset idle timer") + w.Trace().Msg("reset idle timer") w.ticker.Reset(w.IdleTimeout) } @@ -190,7 +207,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, - W.DockerrFilterContainer(w.ContainerID), + W.DockerFilterContainerNameID(w.ContainerID), W.DockerFilterStart, W.DockerFilterStop, W.DockerFilterDie, @@ -214,7 +231,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas // // it exits only if the context is canceled, the container is destroyed, // errors occured on docker client, or route provider died (mainly caused by config reload). -func (w *Watcher) watchUntilDestroy() error { +func (w *Watcher) watchUntilDestroy() (returnCause error) { dockerWatcher := W.NewDockerWatcherWithClient(w.client) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) defer eventTask.Finish("stopped") @@ -224,36 +241,36 @@ func (w *Watcher) watchUntilDestroy() error { case <-w.task.Context().Done(): return w.task.FinishCause() case err := <-dockerEventErrCh: - if err != nil && err.IsNot(context.Canceled) { - w.l.Error(E.FailWith("docker watcher", err)) - return err.Error() + if !err.Is(context.Canceled) { + E.LogError("idlewatcher error", err, &w.Logger) } + return err case e := <-dockerEventCh: switch { case e.Action == events.ActionContainerDestroy: w.ContainerRunning = false w.ready.Store(false) - w.l.Info("watcher stopped by container destruction") + w.LogReason("watcher stopped", "container destroyed") return errors.New("container destroyed") // create / start / unpause case e.Action.IsContainerWake(): w.ContainerRunning = true w.resetIdleTimer() - w.l.Info("container awaken") + w.Info().Msg("awaken") case e.Action.IsContainerSleep(): // stop / pause / kil w.ContainerRunning = false w.ready.Store(false) w.ticker.Stop() default: - w.l.Errorf("unexpected docker event: %s", e) + w.Error().Msg("unexpected docker event: " + e.String()) } // container name changed should also change the container id if w.ContainerName != e.ActorName { - w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName) + w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName) w.ContainerName = e.ActorName } if w.ContainerID != e.ActorID { - w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID) + w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID) w.ContainerID = e.ActorID // recreate event stream eventTask.Finish("recreate event stream") @@ -263,9 +280,9 @@ func (w *Watcher) watchUntilDestroy() error { w.ticker.Stop() if w.ContainerRunning { if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) { - w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err) + w.Err(err).Msgf("container stop with method %q failed", w.StopMethod) } else { - w.l.Info("container stopped by idle timeout") + w.LogReason("container stopped", "idle timeout") } } } diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index 7220dd8..3531af5 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -4,28 +4,26 @@ import ( "context" "errors" "time" - - E "github.com/yusing/go-proxy/internal/error" ) -func Inspect(dockerHost string, containerID string) (*Container, E.Error) { +func Inspect(dockerHost string, containerID string) (*Container, error) { client, err := ConnectClient(dockerHost) defer client.Close() - if err.HasError() { - return nil, E.FailWith("connect to docker", err) + if err != nil { + return nil, err } return client.Inspect(containerID) } -func (c Client) Inspect(containerID string) (*Container, E.Error) { +func (c Client) Inspect(containerID string) (*Container, error) { ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) defer cancel() json, err := c.ContainerInspect(ctx, containerID) if err != nil { - return nil, E.From(err) + return nil, err } return FromJSON(json, c.key), nil } diff --git a/internal/docker/label.go b/internal/docker/label.go index 0d3e8e4..0d13cd5 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -24,6 +24,11 @@ type ( NestedLabelMap map[string]U.SerializedObject ) +var ( + ErrApplyToNil = E.New("label value is nil") + ErrFieldNotExist = E.New("field does not exist") +) + func (l *Label) String() string { if l.Attribute == "" { return l.Namespace + "." + l.Target @@ -41,7 +46,7 @@ func (l *Label) String() string { // - error: an error if the field does not exist. func ApplyLabel[T any](obj *T, l *Label) E.Error { if obj == nil { - return E.Invalid("nil object", l) + return ErrApplyToNil.Subject(l.String()) } switch nestedLabel := l.Value.(type) { case *Label: @@ -54,7 +59,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { } } if !field.IsValid() { - return E.NotExist("field", l.Attribute) + return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String()) } dst, ok := field.Interface().(NestedLabelMap) if !ok { @@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { } else { field = field.Addr() } - return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface()) + err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface()) + if err != nil { + return err.Subject(l.String()) + } + return nil } if dst == nil { field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]())) @@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value return nil default: - return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) + err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) + if err != nil { + return err.Subject(l.String()) + } + return nil } } -func ParseLabel(label string, value string) (*Label, E.Error) { +func ParseLabel(label string, value string) *Label { parts := strings.Split(label, ".") if len(parts) < 2 { return &Label{ Namespace: label, Value: value, - }, nil + } } l := &Label{ @@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.Error) { l.Attribute = parts[2] default: l.Attribute = parts[2] - nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value) - if err != nil { - return nil, err - } + nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value) l.Value = nestedLabel } - return l, nil + return l } diff --git a/internal/docker/label_test.go b/internal/docker/label_test.go index 3f178ea..4591774 100644 --- a/internal/docker/label_test.go +++ b/internal/docker/label_test.go @@ -20,9 +20,8 @@ func makeLabel(ns, name, attr string) string { func TestNestedLabel(t *testing.T) { mAttr := "prop1" - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - sGot := ExpectType[*Label](t, pl.Value) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + sGot := ExpectType[*Label](t, lbl.Value) ExpectFalse(t, sGot == nil) ExpectEqual(t, sGot.Namespace, mName) ExpectEqual(t, sGot.Attribute, mAttr) @@ -32,10 +31,9 @@ func TestApplyNestedLabel(t *testing.T) { entry := new(struct { Middlewares NestedLabelMap `yaml:"middlewares"` }) - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) middleware1, ok := entry.Middlewares[mName] ExpectTrue(t, ok) got := ExpectType[string](t, middleware1[mAttr]) @@ -52,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) { entry.Middlewares[mName] = make(U.SerializedObject) entry.Middlewares[mName][checkAttr] = checkV - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) middleware1, ok := entry.Middlewares[mName] ExpectTrue(t, ok) got := ExpectType[string](t, middleware1[mAttr]) @@ -74,10 +71,9 @@ func TestApplyNestedLabelNoAttr(t *testing.T) { entry.Middlewares = make(NestedLabelMap) entry.Middlewares[mName] = make(U.SerializedObject) - pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) _, ok := entry.Middlewares[mName] ExpectTrue(t, ok) } diff --git a/internal/docker/list_containers.go b/internal/docker/list_containers.go index 285de52..ba9b96e 100644 --- a/internal/docker/list_containers.go +++ b/internal/docker/list_containers.go @@ -8,7 +8,6 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" - E "github.com/yusing/go-proxy/internal/error" ) var listOptions = container.ListOptions{ @@ -23,19 +22,19 @@ var listOptions = container.ListOptions{ All: true, } -func ListContainers(clientHost string) ([]types.Container, E.Error) { +func ListContainers(clientHost string) ([]types.Container, error) { dockerClient, err := ConnectClient(clientHost) - if err.HasError() { - return nil, E.FailWith("connect to docker", err) + if err != nil { + return nil, err } defer dockerClient.Close() ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout")) defer cancel() - containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions)) - if err.HasError() { - return nil, E.FailWith("list containers", err) + containers, err := dockerClient.ContainerList(ctx, listOptions) + if err != nil { + return nil, err } return containers, nil } diff --git a/internal/docker/logger.go b/internal/docker/logger.go index b86d1be..d27f2cb 100644 --- a/internal/docker/logger.go +++ b/internal/docker/logger.go @@ -1,5 +1,7 @@ package docker -import "github.com/sirupsen/logrus" +import ( + "github.com/yusing/go-proxy/internal/logging" +) -var logger = logrus.WithField("module", "docker") +var logger = logging.With().Str("module", "docker").Logger() diff --git a/internal/error/base.go b/internal/error/base.go new file mode 100644 index 0000000..d1287ab --- /dev/null +++ b/internal/error/base.go @@ -0,0 +1,46 @@ +package error + +import ( + "errors" + "fmt" +) + +// baseError is an immutable wrapper around an error. +type baseError struct { + Err error `json:"err"` +} + +func (err *baseError) Unwrap() error { + return err.Err +} + +func (err *baseError) Is(other error) bool { + if other, ok := other.(*baseError); ok { + return errors.Is(err.Err, other.Err) + } + return errors.Is(err.Err, other) +} + +func (err baseError) Subject(subject string) Error { + err.Err = PrependSubject(subject, err.Err) + return &err +} + +func (err *baseError) Subjectf(format string, args ...any) Error { + if len(args) > 0 { + return err.Subject(fmt.Sprintf(format, args...)) + } + return err.Subject(format) +} + +func (err baseError) With(extra error) Error { + return &nestedError{&err, []error{extra}} +} + +func (err baseError) Withf(format string, args ...any) Error { + return &nestedError{&err, []error{fmt.Errorf(format, args...)}} +} + +func (err *baseError) Error() string { + return err.Err.Error() +} diff --git a/internal/error/builder.go b/internal/error/builder.go index cdab20a..12aa474 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -1,104 +1,104 @@ package error import ( - "errors" "fmt" "sync" ) type Builder struct { - *builder -} - -type builder struct { - message string - errors []Error + about string + errs []error sync.Mutex } -func NewBuilder(format string, args ...any) Builder { - if len(args) > 0 { - return Builder{&builder{message: fmt.Sprintf(format, args...)}} +func NewBuilder(about string) *Builder { + return &Builder{about: about} +} + +func (b *Builder) About() string { + if !b.HasError() { + return "" } - return Builder{&builder{message: format}} + return b.about +} + +//go:inline +func (b *Builder) HasError() bool { + return len(b.errs) > 0 +} + +func (b *Builder) Error() Error { + if !b.HasError() { + return nil + } + if len(b.errs) == 1 { + return From(b.errs[0]) + } + return &nestedError{Err: New(b.about), Extras: b.errs} +} + +func (b *Builder) String() string { + if !b.HasError() { + return "" + } + return (&nestedError{Err: New(b.about), Extras: b.errs}).Error() } // Add adds an error to the Builder. // // adding nil is no-op, -// -// flatten is a boolean flag to flatten the NestedError. -func (b Builder) Add(err Error, flatten ...bool) { - if err != nil { - b.Lock() - if len(flatten) > 0 && flatten[0] { - for _, e := range err.extras { - b.errors = append(b.errors, &e) - } +func (b *Builder) Add(err error) *Builder { + if err == nil { + return b + } + + b.Lock() + defer b.Unlock() + + switch err := err.(type) { + case *baseError: + b.errs = append(b.errs, err.Err) + case *nestedError: + if err.Err == nil { + b.errs = append(b.errs, err.Extras...) } else { - b.errors = append(b.errors, err) + b.errs = append(b.errs, err) } - b.Unlock() - } -} - -func (b Builder) AddE(err error) { - b.Add(From(err)) -} - -func (b Builder) Addf(format string, args ...any) { - if len(args) > 0 { - b.Add(errorf(format, args...)) - } else { - b.AddE(errors.New(format)) - } -} - -func (b Builder) AddRange(errs ...Error) { - b.Lock() - defer b.Unlock() - for _, err := range errs { - b.errors = append(b.errors, err) - } -} - -func (b Builder) AddRangeE(errs ...error) { - b.Lock() - defer b.Unlock() - for _, err := range errs { - b.errors = append(b.errors, From(err)) - } -} - -// Build builds a NestedError based on the errors collected in the Builder. -// -// If there are no errors in the Builder, it returns a Nil() NestedError. -// Otherwise, it returns a NestedError with the message and the errors collected. -// -// Returns: -// - NestedError: the built NestedError. -func (b Builder) Build() Error { - if len(b.errors) == 0 { - return nil - } - return Join(b.message, b.errors...) -} - -func (b Builder) To(ptr *Error) { - switch { - case ptr == nil: - return - case *ptr == nil: - *ptr = b.Build() default: - (*ptr).extras = append((*ptr).extras, *b.Build()) + b.errs = append(b.errs, err) } + + return b } -func (b Builder) String() string { - return b.Build().String() +func (b *Builder) Adds(err string) *Builder { + b.Lock() + defer b.Unlock() + b.errs = append(b.errs, newError(err)) + return b } -func (b Builder) HasError() bool { - return len(b.errors) > 0 +func (b *Builder) Addf(format string, args ...any) *Builder { + if len(args) > 0 { + b.Lock() + defer b.Unlock() + b.errs = append(b.errs, fmt.Errorf(format, args...)) + } else { + b.Adds(format) + } + + return b +} + +func (b *Builder) AddRange(errs ...error) *Builder { + b.Lock() + defer b.Unlock() + + for _, err := range errs { + if err != nil { + b.errs = append(b.errs, err) + } + } + + return b } diff --git a/internal/error/builder_test.go b/internal/error/builder_test.go index 2c24948..ec22615 100644 --- a/internal/error/builder_test.go +++ b/internal/error/builder_test.go @@ -1,6 +1,9 @@ package error_test import ( + "context" + "errors" + "io" "testing" . "github.com/yusing/go-proxy/internal/error" @@ -8,14 +11,13 @@ import ( ) func TestBuilderEmpty(t *testing.T) { - eb := NewBuilder("qwer") - ExpectTrue(t, eb.Build() == nil) - ExpectTrue(t, eb.Build().NoError()) + eb := NewBuilder("foo") + ExpectTrue(t, errors.Is(eb.Error(), nil)) ExpectFalse(t, eb.HasError()) } func TestBuilderAddNil(t *testing.T) { - eb := NewBuilder("asdf") + eb := NewBuilder("foo") var err Error for range 3 { eb.Add(nil) @@ -23,41 +25,31 @@ func TestBuilderAddNil(t *testing.T) { for range 3 { eb.Add(err) } - ExpectTrue(t, eb.Build() == nil) - ExpectTrue(t, eb.Build().NoError()) + eb.AddRange(nil, nil, err) ExpectFalse(t, eb.HasError()) + ExpectTrue(t, eb.Error() == nil) +} + +func TestBuilderIs(t *testing.T) { + eb := NewBuilder("foo") + eb.Add(context.Canceled) + eb.Add(io.ErrShortBuffer) + ExpectTrue(t, eb.HasError()) + ExpectError(t, io.ErrShortBuffer, eb.Error()) + ExpectError(t, context.Canceled, eb.Error()) } func TestBuilderNested(t *testing.T) { - eb := NewBuilder("error occurred") - eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2"))) - eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) - - got := eb.Build().String() - expected1 := (`error occurred: - - Action 1 failed: - - invalid Inner: 1 - - invalid Inner: 2 - - Action 2 failed: - - invalid Inner: 3`) - expected2 := (`error occurred: - - Action 1 failed: - - invalid Inner: "1" - - invalid Inner: "2" - - Action 2 failed: - - invalid Inner: "3"`) - ExpectEqualAny(t, got, []string{expected1, expected2}) -} - -func TestBuilderTo(t *testing.T) { - eb := NewBuilder("error occurred") - eb.Addf("abcd") - - var err Error - eb.To(&err) - got := err.String() - expected := (`error occurred: - - abcd`) + eb := NewBuilder("action failed") + eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2")) + eb.Add(New("Action 2").Withf("Inner: 3")) + got := eb.String() + expected := `action failed + • Action 1 + • Inner: 1 + • Inner: 2 + • Action 2 + • Inner: 3` ExpectEqual(t, got, expected) } diff --git a/internal/error/error.go b/internal/error/error.go index 39e7d60..e584653 100644 --- a/internal/error/error.go +++ b/internal/error/error.go @@ -1,317 +1,31 @@ package error -import ( - "encoding/json" - "errors" - "fmt" - "strings" -) +type Error interface { + error -type ( - Error = *ErrorImpl - ErrorImpl struct { - subject string - err error - extras []ErrorImpl - } - ErrorJSONMarshaller struct { - Subject string `json:"subject"` - Err string `json:"error"` - Extras []ErrorJSONMarshaller `json:"extras,omitempty"` - } -) - -func From(err error) Error { - if IsNil(err) { - return nil - } - return &ErrorImpl{err: err} + // Is is a wrapper for errors.Is when there is no sub-error. + // + // When there are sub-errors, they will also be checked. + Is(other error) bool + // With appends a sub-error to the error. + With(extra error) Error + // Withf is a wrapper for With(fmt.Errorf(format, args...)). + Withf(format string, args ...any) Error + // Subject prepends the given subject with a colon and space to the error message. + // + // If there is already a subject in the error message, the subject will be + // prepended to the existing subject with " > ". + // + // Subject empty string is ignored. + Subject(subject string) Error + // Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)). + Subjectf(format string, args ...any) Error } -func FromJSON(data []byte) (Error, bool) { - var j ErrorJSONMarshaller - if err := json.Unmarshal(data, &j); err != nil { - return nil, false - } - if j.Err == "" { - return nil, false - } - extras := make([]ErrorImpl, len(j.Extras)) - for i, e := range j.Extras { - extra, ok := fromJSONObject(e) - if !ok { - return nil, false - } - extras[i] = *extra - } - return &ErrorImpl{ - subject: j.Subject, - err: errors.New(j.Err), - extras: extras, - }, true -} - -func TryUnwrap(err error) error { - if err == nil { - return nil - } - if unwrapped := errors.Unwrap(err); unwrapped != nil { - return unwrapped - } - return err -} - -// Check is a helper function that -// convert (T, error) to (T, NestedError). -func Check[T any](obj T, err error) (T, Error) { - return obj, From(err) -} - -func Join(message string, err ...Error) Error { - extras := make([]ErrorImpl, len(err)) - nErr := 0 - for i, e := range err { - if e == nil { - continue - } - extras[i] = *e - nErr++ - } - if nErr == 0 { - return nil - } - return &ErrorImpl{ - err: errors.New(message), - extras: extras, - } -} - -func JoinE(message string, err ...error) Error { - b := NewBuilder("%s", message) - for _, e := range err { - b.AddE(e) - } - return b.Build() -} - -func IsNil(err error) bool { - return err == nil -} - -func IsNotNil(err error) bool { - return err != nil -} - -func (ne Error) String() string { - var buf strings.Builder - ne.writeToSB(&buf, 0, "") - return buf.String() -} - -func (ne Error) Is(err error) bool { - if ne == nil { - return err == nil - } - // return errors.Is(ne.err, err) - if errors.Is(ne.err, err) { - return true - } - for _, e := range ne.extras { - if e.Is(err) { - return true - } - } - return false -} - -func (ne Error) IsNot(err error) bool { - return !ne.Is(err) -} - -func (ne Error) Error() error { - if ne == nil { - return nil - } - return ne.buildError(0, "") -} - -func (ne Error) With(s any) Error { - if ne == nil { - return ne - } - var msg string - switch ss := s.(type) { - case nil: - return ne - case *ErrorImpl: - if len(ss.extras) == 1 { - ne.extras = append(ne.extras, ss.extras[0]) - return ne - } - return ne.withError(ss) - case error: - // unwrap only once - return ne.withError(From(TryUnwrap(ss))) - case string: - msg = ss - case fmt.Stringer: - return ne.appendMsg(ss.String()) - default: - return ne.appendMsg(fmt.Sprint(s)) - } - return ne.withError(From(errors.New(msg))) -} - -func (ne Error) Extraf(format string, args ...any) Error { - return ne.With(errorf(format, args...)) -} - -func (ne Error) Subject(s any, sep ...string) Error { - if ne == nil { - return ne - } - var subject string - switch ss := s.(type) { - case string: - subject = ss - case fmt.Stringer: - subject = ss.String() - default: - subject = fmt.Sprint(s) - } - switch { - case ne.subject == "": - ne.subject = subject - case len(sep) > 0: - ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject) - default: - ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject) - } - return ne -} - -func (ne Error) Subjectf(format string, args ...any) Error { - if ne == nil { - return ne - } - return ne.Subject(fmt.Sprintf(format, args...)) -} - -func (ne Error) JSONObject() ErrorJSONMarshaller { - extras := make([]ErrorJSONMarshaller, len(ne.extras)) - for i, e := range ne.extras { - extras[i] = e.JSONObject() - } - return ErrorJSONMarshaller{ - Subject: ne.subject, - Err: ne.err.Error(), - Extras: extras, - } -} - -func (ne Error) JSON() []byte { - b, err := json.MarshalIndent(ne.JSONObject(), "", " ") - if err != nil { - panic(err) - } - return b -} - -func (ne Error) NoError() bool { - return ne == nil -} - -func (ne Error) HasError() bool { - return ne != nil -} - -func errorf(format string, args ...any) Error { - for i, arg := range args { - if err, ok := arg.(error); ok { - if unwrapped := errors.Unwrap(err); unwrapped != nil { - args[i] = unwrapped - } - } - } - return From(fmt.Errorf(format, args...)) -} - -func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) { - data, err := json.Marshal(obj) - if err != nil { - return nil, false - } - return FromJSON(data) -} - -func (ne Error) withError(err Error) Error { - if ne != nil && err != nil { - ne.extras = append(ne.extras, *err) - } - return ne -} - -func (ne Error) appendMsg(msg string) Error { - if ne == nil { - return nil - } - ne.err = fmt.Errorf("%w %s", ne.err, msg) - return ne -} - -func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) { - for range level { - sb.WriteString(" ") - } - sb.WriteString(prefix) - - if ne.NoError() { - sb.WriteString("nil") - return - } - - if ne.subject != "" { - sb.WriteString(ne.subject) - sb.WriteRune(' ') - } - sb.WriteString(ne.err.Error()) - if len(ne.extras) > 0 { - sb.WriteRune(':') - for _, extra := range ne.extras { - sb.WriteRune('\n') - extra.writeToSB(sb, level+1, "- ") - } - } -} - -func (ne Error) buildError(level int, prefix string) error { - var res error - var sb strings.Builder - - for range level { - sb.WriteString(" ") - } - sb.WriteString(prefix) - - if ne.NoError() { - sb.WriteString("nil") - return errors.New(sb.String()) - } - - res = fmt.Errorf("%s%w", sb.String(), ne.err) - sb.Reset() - - if ne.subject != "" { - sb.WriteString(fmt.Sprintf(" for %q", ne.subject)) - } - if len(ne.extras) > 0 { - sb.WriteRune(':') - res = fmt.Errorf("%w%s", res, sb.String()) - for _, extra := range ne.extras { - res = errors.Join(res, extra.buildError(level+1, "- ")) - } - } else { - res = fmt.Errorf("%w%s", res, sb.String()) - } - return res +// this makes JSON marshalling work, +// as the builtin one doesn't. +type errStr string + +func (err errStr) Error() string { + return string(err) } diff --git a/internal/error/error_test.go b/internal/error/error_test.go index 5d0b9ea..0ac2902 100644 --- a/internal/error/error_test.go +++ b/internal/error/error_test.go @@ -1,107 +1,157 @@ -package error_test +package error import ( "errors" + "strings" "testing" - . "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) +func TestBaseString(t *testing.T) { + ExpectEqual(t, New("error").Error(), "error") +} + +func TestBaseWithSubject(t *testing.T) { + err := New("error") + withSubject := err.Subject("foo") + withSubjectf := err.Subjectf("%s %s", "foo", "bar") + + ExpectError(t, err, withSubject) + ExpectStrEqual(t, withSubject.Error(), "foo: error") + ExpectTrue(t, withSubject.Is(err)) + + ExpectError(t, err, withSubjectf) + ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error") + ExpectTrue(t, withSubjectf.Is(err)) +} + +func TestBaseWithExtra(t *testing.T) { + err := New("error") + extra := New("bar").Subject("baz") + withExtra := err.With(extra) + + ExpectTrue(t, withExtra.Is(extra)) + ExpectTrue(t, withExtra.Is(err)) + + ExpectTrue(t, errors.Is(withExtra, extra)) + ExpectTrue(t, errors.Is(withExtra, err)) + + ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error())) + ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error())) + ExpectTrue(t, strings.Contains(withExtra.Error(), "baz")) +} + +func TestBaseUnwrap(t *testing.T) { + err := errors.New("err") + wrapped := From(err) + + ExpectError(t, err, errors.Unwrap(wrapped)) +} + +func TestNestedUnwrap(t *testing.T) { + err := errors.New("err") + err2 := New("err2") + wrapped := From(err).Subject("foo").With(err2.Subject("bar")) + + unwrapper, ok := wrapped.(interface{ Unwrap() []error }) + ExpectTrue(t, ok) + + ExpectError(t, err, wrapped) + ExpectError(t, err2, wrapped) + ExpectEqual(t, len(unwrapper.Unwrap()), 2) +} + func TestErrorIs(t *testing.T) { - ExpectTrue(t, Failure("foo").Is(ErrFailure)) - ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure)) - ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid)) - ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid)) + from := errors.New("error") + err := From(from) + ExpectError(t, from, err) - ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid)) - ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure)) + ExpectTrue(t, err.Is(from)) + ExpectFalse(t, err.Is(New("error"))) - ExpectFalse(t, Invalid("foo", "bar").Is(nil)) - - ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure)) - ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid)) - ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure)) - ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists)) + ExpectTrue(t, errors.Is(err.Subject("foo"), from)) + ExpectTrue(t, errors.Is(err.Withf("foo"), from)) + ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from)) } -func TestErrorNestedIs(t *testing.T) { - var err Error - ExpectTrue(t, err.Is(nil)) +func TestErrorImmutability(t *testing.T) { + err := New("err") + err2 := New("err2") - err = Failure("some reason") - ExpectTrue(t, err.Is(ErrFailure)) - ExpectFalse(t, err.Is(ErrDuplicated)) + for range 3 { + // t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err) + err.Subject("foo") + ExpectFalse(t, strings.Contains(err.Error(), "foo")) - err.With(Duplicated("something", "")) - ExpectTrue(t, err.Is(ErrFailure)) - ExpectTrue(t, err.Is(ErrDuplicated)) - ExpectFalse(t, err.Is(ErrInvalid)) -} + err.With(err2) + ExpectFalse(t, strings.Contains(err.Error(), "extra")) + ExpectFalse(t, err.Is(err2)) -func TestIsNil(t *testing.T) { - var err Error - ExpectTrue(t, err.Is(nil)) - ExpectTrue(t, err == nil) - ExpectTrue(t, err.NoError()) - - eb := NewBuilder("") - returnNil := func() error { - return eb.Build().Error() + err = err.Subject("bar").Withf("baz") + ExpectTrue(t, err != nil) } - ExpectTrue(t, IsNil(returnNil())) - ExpectTrue(t, returnNil() == nil) - - ExpectTrue(t, (err. - Subject("any"). - With("something"). - Extraf("foo %s", "bar")) == nil) -} - -func TestErrorSimple(t *testing.T) { - ne := Failure("foo bar") - ExpectEqual(t, ne.String(), "foo bar failed") - ne = ne.Subject("baz") - ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"") } func TestErrorWith(t *testing.T) { - ne := Failure("foo").With("bar").With("baz") - ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz") + err1 := New("err1") + err2 := New("err2") + + err3 := err1.With(err2) + + ExpectTrue(t, err3.Is(err1)) + ExpectTrue(t, err3.Is(err2)) + + err2.Subject("foo") + + ExpectTrue(t, err3.Is(err1)) + ExpectTrue(t, err3.Is(err2)) + + // check if err3 is affected by err2.Subject + ExpectFalse(t, strings.Contains(err3.Error(), "foo")) } -func TestErrorNested(t *testing.T) { - inner := Failure("inner"). - With("1"). - With("1") - inner2 := Failure("inner2"). +func TestErrorStringSimple(t *testing.T) { + errFailure := New("generic failure") + ne := errFailure.Subject("foo bar") + ExpectStrEqual(t, ne.Error(), "foo bar: generic failure") + ne = ne.Subject("baz") + ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure") +} + +func TestErrorStringNested(t *testing.T) { + errFailure := New("generic failure") + inner := errFailure.Subject("inner"). + Withf("1"). + Withf("1") + inner2 := errFailure.Subject("inner2"). Subject("action 2"). - With("2"). - With("2") - inner3 := Failure("inner3"). + Withf("2"). + Withf("2") + inner3 := errFailure.Subject("inner3"). Subject("action 3"). - With("3"). - With("3") - ne := Failure("foo"). - With("bar"). - With("baz"). + Withf("3"). + Withf("3") + ne := errFailure. + Subject("foo"). + Withf("bar"). + Withf("baz"). With(inner). With(inner.With(inner2.With(inner3))) - want := `foo failed: - - bar - - baz - - inner failed: - - 1 - - 1 - - inner failed: - - 1 - - 1 - - inner2 failed for "action 2": - - 2 - - 2 - - inner3 failed for "action 3": - - 3 - - 3` - ExpectEqual(t, ne.String(), want) - ExpectEqual(t, ne.Error().Error(), want) + want := `foo: generic failure + • bar + • baz + • inner: generic failure + • 1 + • 1 + • inner: generic failure + • 1 + • 1 + • action 2 > inner2: generic failure + • 2 + • 2 + • action 3 > inner3: generic failure + • 3 + • 3` + ExpectStrEqual(t, ne.Error(), want) } diff --git a/internal/error/errors.go b/internal/error/errors.go deleted file mode 100644 index 4ae9214..0000000 --- a/internal/error/errors.go +++ /dev/null @@ -1,83 +0,0 @@ -package error - -import ( - stderrors "errors" - "fmt" - "reflect" -) - -var ( - ErrFailure = stderrors.New("failed") - ErrInvalid = stderrors.New("invalid") - ErrUnsupported = stderrors.New("unsupported") - ErrUnexpected = stderrors.New("unexpected") - ErrNotExists = stderrors.New("does not exist") - ErrMissing = stderrors.New("missing") - ErrDuplicated = stderrors.New("duplicated") - ErrOutOfRange = stderrors.New("out of range") - ErrTypeError = stderrors.New("type error") - ErrTypeMismatch = stderrors.New("type mismatch") - ErrPanicRecv = stderrors.New("panic recovered from") -) - -const fmtSubjectWhat = "%w %v: %q" - -func Failure(what string) Error { - return errorf("%s %w", what, ErrFailure) -} - -func FailedWhy(what string, why string) Error { - return Failure(what).With(why) -} - -func FailWith(what string, err any) Error { - return Failure(what).With(err) -} - -func Invalid(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrInvalid, subject, what) -} - -func Unsupported(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrUnsupported, subject, what) -} - -func Unexpected(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrUnexpected, subject, what) -} - -func UnexpectedError(err error) Error { - return errorf("%w error: %w", ErrUnexpected, err) -} - -func NotExist(subject, what any) Error { - return errorf("%v %w: %v", subject, ErrNotExists, what) -} - -func Missing(subject any) Error { - return errorf("%w %v", ErrMissing, subject) -} - -func Duplicated(subject, what any) Error { - return errorf("%w %v: %v", ErrDuplicated, subject, what) -} - -func OutOfRange(subject any, value any) Error { - return errorf("%v %w: %v", subject, ErrOutOfRange, value) -} - -func TypeError(subject any, from, to reflect.Type) Error { - return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to) -} - -func TypeError2(subject any, from, to reflect.Value) Error { - return TypeError(subject, from.Type(), to.Type()) -} - -func TypeMismatch[Expect any](value any) Error { - return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) -} - -func PanicRecv(format string, args ...any) Error { - return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...)) -} diff --git a/internal/error/log.go b/internal/error/log.go new file mode 100644 index 0000000..56990a8 --- /dev/null +++ b/internal/error/log.go @@ -0,0 +1,43 @@ +package error + +import ( + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/logging" +) + +func getLogger(logger ...*zerolog.Logger) *zerolog.Logger { + if len(logger) > 0 { + return logger[0] + } + return logging.GetLogger() +} + +//go:inline +func LogFatal(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Fatal().Msg(err.Error()) +} + +//go:inline +func LogError(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Error().Msg(err.Error()) +} + +//go:inline +func LogWarn(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Warn().Msg(err.Error()) +} + +//go:inline +func LogPanic(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Panic().Msg(err.Error()) +} + +//go:inline +func LogInfo(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Info().Msg(err.Error()) +} + +//go:inline +func LogDebug(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Debug().Msg(err.Error()) +} diff --git a/internal/error/nested_error.go b/internal/error/nested_error.go new file mode 100644 index 0000000..a2fdc61 --- /dev/null +++ b/internal/error/nested_error.go @@ -0,0 +1,120 @@ +package error + +import ( + "errors" + "fmt" + "strings" +) + +type nestedError struct { + Err error `json:"err"` + Extras []error `json:"extras"` +} + +func (err nestedError) Subject(subject string) Error { + if err.Err == nil { + err.Err = newError(subject) + } else { + err.Err = PrependSubject(subject, err.Err) + } + return &err +} + +func (err *nestedError) Subjectf(format string, args ...any) Error { + if len(args) > 0 { + return err.Subject(fmt.Sprintf(format, args...)) + } + return err.Subject(format) +} + +func (err nestedError) With(extra error) Error { + if extra != nil { + err.Extras = append(err.Extras, extra) + } + return &err +} + +func (err nestedError) Withf(format string, args ...any) Error { + if len(args) > 0 { + err.Extras = append(err.Extras, fmt.Errorf(format, args...)) + } else { + err.Extras = append(err.Extras, newError(format)) + } + return &err +} + +func (err *nestedError) Unwrap() []error { + if err.Err == nil { + if len(err.Extras) == 0 { + return nil + } + return err.Extras + } + return append([]error{err.Err}, err.Extras...) +} + +func (err *nestedError) Is(other error) bool { + if errors.Is(err.Err, other) { + return true + } + for _, e := range err.Extras { + if errors.Is(e, other) { + return true + } + } + return false +} + +func (err *nestedError) Error() string { + return buildError(err, 0) +} + +//go:inline +func makeLine(err string, level int) string { + const bulletPrefix = "• " + const spaces = " " + + if level == 0 { + return err + } + return spaces[:2*level] + bulletPrefix + err +} + +func makeLines(errs []error, level int) []string { + if len(errs) == 0 { + return nil + } + lines := make([]string, 0, len(errs)) + for _, err := range errs { + switch err := err.(type) { + case *nestedError: + if err.Err != nil { + lines = append(lines, makeLine(err.Err.Error(), level)) + } + if extras := makeLines(err.Extras, level+1); len(extras) > 0 { + lines = append(lines, extras...) + } + default: + lines = append(lines, makeLine(err.Error(), level)) + } + } + return lines +} + +func buildError(err error, level int) string { + switch err := err.(type) { + case nil: + return makeLine("", level) + case *nestedError: + lines := make([]string, 0, 1+len(err.Extras)) + if err.Err != nil { + lines = append(lines, makeLine(err.Err.Error(), level)) + } + if extras := makeLines(err.Extras, level+1); len(extras) > 0 { + lines = append(lines, extras...) + } + return strings.Join(lines, "\n") + default: + return makeLine(err.Error(), level) + } +} diff --git a/internal/error/subject.go b/internal/error/subject.go new file mode 100644 index 0000000..c78ef2d --- /dev/null +++ b/internal/error/subject.go @@ -0,0 +1,50 @@ +package error + +import ( + "strings" + + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" +) + +type withSubject struct { + Subject string `json:"subject"` + Err error `json:"err"` +} + +const subjectSep = " > " + +func highlight(subject string) string { + return ansi.HighlightRed + subject + ansi.Reset +} + +func PrependSubject(subject string, err error) *withSubject { + switch err := err.(type) { + case *withSubject: + return err.Prepend(subject) + case *baseError: + return PrependSubject(subject, err.Err) + default: + return &withSubject{subject, err} + } +} + +func (err withSubject) Prepend(subject string) *withSubject { + if subject != "" { + err.Subject = subject + subjectSep + err.Subject + } + return &err +} + +func (err *withSubject) Is(other error) bool { + return err.Err == other +} + +func (err *withSubject) Unwrap() error { + return err.Err +} + +func (err *withSubject) Error() string { + subjects := strings.Split(err.Subject, subjectSep) + subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1]) + return strings.Join(subjects, subjectSep) + ": " + err.Err.Error() +} diff --git a/internal/error/utils.go b/internal/error/utils.go new file mode 100644 index 0000000..6bd6c73 --- /dev/null +++ b/internal/error/utils.go @@ -0,0 +1,71 @@ +package error + +import ( + "errors" + "fmt" +) + +var ErrInvalidErrorJson = errors.New("invalid error json") + +func newError(message string) error { + return errStr(message) +} + +func New(message string) Error { + if message == "" { + return nil + } + return &baseError{newError(message)} +} + +func Errorf(format string, args ...any) Error { + return &baseError{fmt.Errorf(format, args...)} +} + +func From(err error) Error { + if err == nil { + return nil + } + if err, ok := err.(Error); ok { + return err + } + return &baseError{err} +} + +func Must[T any](v T, err error) T { + if err != nil { + LogPanic("must failed", err) + } + return v +} + +func Join(errors ...error) Error { + n := 0 + for _, err := range errors { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + errs := make([]error, 0, n) + for _, err := range errors { + if err != nil { + errs = append(errs, err) + } + } + return &nestedError{Extras: errs} +} + +func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T { + result, err := fn(arg) + eb.Add(err) + return result +} + +func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T { + result, err := fn(arg1, arg2) + eb.Add(err) + return result +} diff --git a/internal/list-icons.go b/internal/list-icons.go index 1764901..8aff4bc 100644 --- a/internal/list-icons.go +++ b/internal/list-icons.go @@ -53,7 +53,7 @@ func ListAvailableIcons() ([]string, error) { icons = append(icons, content.Path) } } - err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error() + err = utils.SaveJSON(iconsCachePath, &icons, 0o644) if err != nil { log.Print("error saving cache", err) } diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..664ea4a --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,69 @@ +package logging + +import ( + "os" + "strings" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/common" +) + +var logger zerolog.Logger + +func init() { + var timeFmt string + var level zerolog.Level + var exclude []string + + if common.IsTrace { + timeFmt = "04:05" + level = zerolog.TraceLevel + } else if common.IsDebug { + timeFmt = "01-02 15:04" + level = zerolog.DebugLevel + } else { + timeFmt = "01-02 15:04" + level = zerolog.InfoLevel + exclude = []string{"module"} + } + + prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces + prefix := strings.Repeat(" ", prefixLength) + + logger = zerolog.New( + zerolog.ConsoleWriter{ + Out: os.Stderr, + TimeFormat: timeFmt, + FieldsExclude: exclude, + FormatMessage: func(msgI interface{}) string { // pad spaces for each line + msg := msgI.(string) + lines := strings.Split(msg, "\n") + if len(lines) == 1 { + return msg + } + for i := 1; i < len(lines); i++ { + lines[i] = prefix + lines[i] + } + return strings.Join(lines, "\n") + }, + }, + ).Level(level).With().Timestamp().Logger() +} + +func DiscardLogger() { logger = zerolog.Nop() } + +func AddHook(h zerolog.Hook) { logger = logger.Hook(h) } + +func GetLogger() *zerolog.Logger { return &logger } +func With() zerolog.Context { return logger.With() } + +func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) } + +func Info() *zerolog.Event { return logger.Info() } +func Warn() *zerolog.Event { return logger.Warn() } +func Error() *zerolog.Event { return logger.Error() } +func Err(err error) *zerolog.Event { return logger.Err(err) } +func Debug() *zerolog.Event { return logger.Debug() } +func Fatal() *zerolog.Event { return logger.Fatal() } +func Panic() *zerolog.Event { return logger.Panic() } +func Trace() *zerolog.Event { return logger.Trace() } diff --git a/internal/net/http/dummy_response_writer.go b/internal/net/http/dummy_response_writer.go new file mode 100644 index 0000000..0e5a1a9 --- /dev/null +++ b/internal/net/http/dummy_response_writer.go @@ -0,0 +1,15 @@ +package http + +import "net/http" + +type DummyResponseWriter struct{} + +func (w DummyResponseWriter) Header() http.Header { + return make(http.Header) +} + +func (w DummyResponseWriter) Write([]byte) (_ int, _ error) { + return +} + +func (w DummyResponseWriter) WriteHeader(int) {} diff --git a/internal/net/http/loadbalancer/dummy_response_writer.go b/internal/net/http/loadbalancer/dummy_response_writer.go deleted file mode 100644 index d6ea9f0..0000000 --- a/internal/net/http/loadbalancer/dummy_response_writer.go +++ /dev/null @@ -1,15 +0,0 @@ -package loadbalancer - -import "net/http" - -type DummyResponseWriter struct{} - -func (w *DummyResponseWriter) Header() (_ http.Header) { - return -} - -func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) { - return -} - -func (w *DummyResponseWriter) WriteHeader(int) {} diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index 62e1a51..f952932 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -11,20 +11,22 @@ import ( ) type ipHash struct { + *LoadBalancer + realIP *middleware.Middleware pool servers mu sync.Mutex } func (lb *LoadBalancer) newIPHash() impl { - impl := new(ipHash) + impl := &ipHash{LoadBalancer: lb} if len(lb.Options) == 0 { return impl } var err E.Error impl.realIP, err = middleware.NewRealIP(lb.Options) if err != nil { - logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) + E.LogError("invalid real_ip options, ignoring", err, &impl.Logger) } return impl } @@ -70,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { http.Error(rw, "Internal error", http.StatusInternalServerError) - logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err) + impl.Err(err).Msg("invalid remote address " + r.RemoteAddr) return } idx := hashIP(ip) % uint32(len(impl.pool)) diff --git a/internal/net/http/loadbalancer/least_conn.go b/internal/net/http/loadbalancer/least_conn.go index 2ca9794..8fe8894 100644 --- a/internal/net/http/loadbalancer/least_conn.go +++ b/internal/net/http/loadbalancer/least_conn.go @@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.R srv := srvs[0] minConn, ok := impl.nConn.Load(srv) if !ok { - logger.Errorf("[BUG] server %s not found", srv.Name) + impl.Error().Msgf("[BUG] server %s not found", srv.Name) http.Error(rw, "Internal error", http.StatusInternalServerError) } for i := 1; i < len(srvs); i++ { nConn, ok := impl.nConn.Load(srvs[i]) if !ok { - logger.Errorf("[BUG] server %s not found", srv.Name) + impl.Error().Msgf("[BUG] server %s not found", srv.Name) http.Error(rw, "Internal error", http.StatusInternalServerError) } if nConn.Load() < minConn.Load() { diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index ed860e6..4812116 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -6,9 +6,11 @@ import ( "sync" "time" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health" @@ -29,6 +31,8 @@ type ( Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"` } LoadBalancer struct { + zerolog.Logger + impl *Config @@ -48,6 +52,7 @@ const maxWeight weightType = 100 func New(cfg *Config) *LoadBalancer { lb := &LoadBalancer{ + Logger: logger.With().Str("name", cfg.Link).Logger(), Config: new(Config), pool: newPool(), task: task.DummyTask(), @@ -102,7 +107,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { if lb.Mode == Unset && cfg.Mode != Unset { lb.Mode = cfg.Mode if !lb.Mode.ValidateUpdate() { - logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode) + lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode) } lb.updateImpl() } @@ -131,7 +136,11 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.rebalance() lb.impl.OnAddServer(srv) - logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) + + lb.Debug(). + Str("action", "add"). + Str("server", srv.Name). + Msgf("%d servers available", lb.pool.Size()) } func (lb *LoadBalancer) RemoveServer(srv *Server) { @@ -148,13 +157,15 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { lb.rebalance() lb.impl.OnRemoveServer(srv) + lb.Debug(). + Str("action", "remove"). + Str("server", srv.Name). + Msgf("%d servers left", lb.pool.Size()) + if lb.pool.Size() == 0 { lb.task.Finish("no server left") - logger.Infof("loadbalancer %s stopped", lb.Link) return } - - logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) } func (lb *LoadBalancer) rebalance() { @@ -218,7 +229,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) defer cancel() // send dummy request to wake all servers - var dummyRW *DummyResponseWriter + var dummyRW gphttp.DummyResponseWriter for _, srv := range srvs { // wake only if server implements Waker _, ok := srv.handler.(idlewatcher.Waker) diff --git a/internal/net/http/loadbalancer/logger.go b/internal/net/http/loadbalancer/logger.go index 7b9b51d..30fac46 100644 --- a/internal/net/http/loadbalancer/logger.go +++ b/internal/net/http/loadbalancer/logger.go @@ -1,5 +1,5 @@ package loadbalancer -import "github.com/sirupsen/logrus" +import "github.com/yusing/go-proxy/internal/logging" -var logger = logrus.WithField("module", "load_balancer") +var logger = logging.With().Str("module", "load_balancer").Logger() diff --git a/internal/net/http/logger.go b/internal/net/http/logger.go new file mode 100644 index 0000000..3e52c93 --- /dev/null +++ b/internal/net/http/logger.go @@ -0,0 +1,5 @@ +package http + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "http").Logger() diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index 5702851..ec19869 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -15,9 +15,9 @@ type cidrWhitelist struct { } type cidrWhitelistOpts struct { - Allow []*types.CIDR - StatusCode int - Message string + Allow []*types.CIDR `json:"allow"` + StatusCode int `json:"statusCode"` + Message string `json:"message"` cachedAddr F.Map[string, bool] // cache for trusted IPs } @@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { return nil, err } if len(wl.cidrWhitelistOpts.Allow) == 0 { - return nil, E.Missing("allow range") + return nil, E.New("no allowed CIDRs") } return wl.m, nil } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index 0daeb9d..dd5fc69 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte var deny, accept *Middleware func TestCIDRWhitelist(t *testing.T) { - mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose) - if err != nil { - panic(err) - } + errs := E.NewBuilder("") + mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) + ExpectNoError(t, errs.Error()) deny = mids["deny@file"] accept = mids["accept@file"] if deny == nil || accept == nil { @@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) { t.Run("deny", func(t *testing.T) { for range 10 { result, err := newMiddlewareTest(deny, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode) ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message) } @@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) { t.Run("accept", func(t *testing.T) { for range 10 { result, err := newMiddlewareTest(accept, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusOK) } }) diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 20e7ff3..f2c5b92 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -10,10 +10,10 @@ import ( "sync" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/strutils" ) const ( @@ -26,7 +26,7 @@ const ( var ( cfCIDRsLastUpdate time.Time cfCIDRsMu sync.Mutex - cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP") + cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger() ) var CloudflareRealIP = &realIP{ @@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { ) if err != nil { cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) - cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval) + cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) return nil } } cfCIDRsLastUpdate = time.Now() - cfCIDRsLogger.Debugf("cloudflare CIDR range updated") + cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated") return } diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index f06b686..6b764e2 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -8,24 +8,31 @@ import ( "strconv" "strings" - "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" ) -var CustomErrorPage = &Middleware{ - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - if !ServeStaticErrorPageFile(w, r) { - next(w, r) - } - }, - modifyResponse: func(resp *Response) error { +var CustomErrorPage *Middleware + +func init() { + CustomErrorPage = customErrorPage() +} + +func customErrorPage() *Middleware { + m := &Middleware{ + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { + if !ServeStaticErrorPageFile(w, r) { + next(w, r) + } + }, + } + m.modifyResponse = func(resp *Response) error { // only handles non-success status code and html/plain content type contentType := gphttp.GetContentType(resp.Header) if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) if ok { - errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode) + CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode) /* trunk-ignore(golangci-lint/errcheck) */ io.Copy(io.Discard, resp.Body) // drain the original body resp.Body.Close() @@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{ resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) resp.Header.Set("Content-Type", "text/html; charset=utf-8") } else { - errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode) + CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } return nil } return nil - }, + } + return m } func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { @@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { filename := path[len(gphttp.StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { - errPageLogger.Errorf("unable to load resource %s", filename) + logger.Error().Msg("unable to load resource " + filename) return false } ext := filepath.Ext(filename) @@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { case ".css": w.Header().Set("Content-Type", "text/css; charset=utf-8") default: - errPageLogger.Errorf("unexpected file type %q for %s", ext, filename) + logger.Error().Msgf("unexpected file type %q for %s", ext, filename) } if _, err := w.Write(file); err != nil { - errPageLogger.WithError(err).Errorf("unable to write resource %s", filename) + logger.Err(err).Msg("unable to write resource " + filename) http.Error(w, "Error page failure", http.StatusInternalServerError) } return true } return false } - -var errPageLogger = logrus.WithField("middleware", "error_page") diff --git a/internal/api/v1/errorpage/error_page.go b/internal/net/http/middleware/errorpage/error_page.go similarity index 66% rename from internal/api/v1/errorpage/error_page.go rename to internal/net/http/middleware/errorpage/error_page.go index cb796bf..7f63d5d 100644 --- a/internal/api/v1/errorpage/error_page.go +++ b/internal/net/http/middleware/errorpage/error_page.go @@ -1,14 +1,15 @@ package errorpage import ( - "context" "fmt" "os" "path" "sync" - . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" W "github.com/yusing/go-proxy/internal/watcher" @@ -23,9 +24,10 @@ var ( ) var setup = sync.OnceFunc(func() { - dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath) + task := task.GlobalTask("error page") + dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath) loadContent() - go watchDir() + go watchDir(task) }) func GetStaticFile(filename string) ([]byte, bool) { @@ -44,7 +46,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { func loadContent() { files, err := U.ListFiles(errPagesBasePath, 0) if err != nil { - Logger.Error(err) + logger.Err(err).Msg("failed to list error page resources") return } for _, file := range files { @@ -53,19 +55,21 @@ func loadContent() { } content, err := os.ReadFile(file) if err != nil { - Logger.Errorf("failed to read error page resource %s: %s", file, err) + logger.Warn().Err(err).Msgf("failed to read error page resource %s", file) continue } file = path.Base(file) - Logger.Infof("error page resource %s loaded", file) + logging.Info().Msgf("error page resource %s loaded", file) fileContentMap.Store(file, content) } } -func watchDir() { - eventCh, errCh := dirWatcher.Events(context.Background()) +func watchDir(task task.Task) { + eventCh, errCh := dirWatcher.Events(task.Context()) for { select { + case <-task.Context().Done(): + return case event, ok := <-eventCh: if !ok { return @@ -77,14 +81,14 @@ func watchDir() { loadContent() case events.ActionFileDeleted: fileContentMap.Delete(filename) - Logger.Infof("error page resource %s deleted", filename) + logger.Warn().Msgf("error page resource %s deleted", filename) case events.ActionFileRenamed: - Logger.Infof("error page resource %s deleted", filename) + logger.Warn().Msgf("error page resource %s deleted", filename) fileContentMap.Delete(filename) loadContent() } case err := <-errCh: - Logger.Errorf("error watching error page directory: %s", err) + E.LogError("error watching error page directory", err, &logger) } } } diff --git a/internal/net/http/middleware/errorpage/logger.go b/internal/net/http/middleware/errorpage/logger.go new file mode 100644 index 0000000..bc0fc30 --- /dev/null +++ b/internal/net/http/middleware/errorpage/logger.go @@ -0,0 +1,5 @@ +package errorpage + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "errorpage").Logger() diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 93e51d5..29ce004 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -24,11 +24,12 @@ type ( client http.Client } forwardAuthOpts struct { - Address string - TrustForwardHeader bool - AuthResponseHeaders []string - AddAuthCookiesToResponse []string - transport http.RoundTripper + Address string `json:"address"` + TrustForwardHeader bool `json:"trustForwardHeader"` + AuthResponseHeaders []string `json:"authResponseHeaders"` + AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"` + + transport http.RoundTripper } ) @@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { fa := new(forwardAuth) fa.forwardAuthOpts = new(forwardAuthOpts) - err := Deserialize(optsRaw, fa.forwardAuthOpts) - if err != nil { + if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil { return nil, err } - _, err = E.Check(url.Parse(fa.Address)) - if err != nil { - return nil, E.Invalid("address", fa.Address) + if _, err := url.Parse(fa.Address); err != nil { + return nil, E.From(err) } fa.m = &Middleware{ diff --git a/internal/net/http/middleware/logger.go b/internal/net/http/middleware/logger.go new file mode 100644 index 0000000..643f9b6 --- /dev/null +++ b/internal/net/http/middleware/logger.go @@ -0,0 +1,5 @@ +package middleware + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "middleware").Logger() diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 804450a..0b2b80e 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" U "github.com/yusing/go-proxy/internal/utils" @@ -32,6 +33,8 @@ type ( Middleware struct { _ U.NoCopy + zerolog.Logger + name string before BeforeFunc // runs before ReverseProxy.ServeHTTP @@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) { } func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { - if len(optsRaw) != 0 && m.withOptions != nil { - return m.withOptions(optsRaw) + if m.withOptions != nil { + m, err := m.withOptions(optsRaw) + if err != nil { + return nil, err + } + m.Logger = logger.With().Str("name", m.name).Logger() + return m, nil } // WithOptionsClone is called only once // set withOptions and labelParser will not be used after that return &Middleware{ + Logger: logger.With().Str("name", m.name).Logger(), name: m.name, before: m.before, modifyResponse: m.modifyResponse, @@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error { } // TODO: check conflict or duplicates. -func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) { - middlewares = make([]*Middleware, 0, len(middlewaresMap)) +func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { + middlewares := make([]*Middleware, 0, len(middlewaresMap)) - invalidM := E.NewBuilder("invalid middlewares") - invalidOpts := E.NewBuilder("invalid options") - defer func() { - invalidM.Add(invalidOpts.Build()) - invalidM.To(&res) - }() + errs := E.NewBuilder("middlewares compile error") + invalidOpts := E.NewBuilder("options compile error") for name, opts := range middlewaresMap { - m, ok := Get(name) - if !ok { - invalidM.Add(E.NotExist("middleware", name)) + m, err := Get(name) + if err != nil { + errs.Add(err) continue } - m, err := m.WithOptionsClone(opts) + m, err = m.WithOptionsClone(opts) if err != nil { invalidOpts.Add(err.Subject(name)) continue @@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid middlewares = append(middlewares, m) } - return + if invalidOpts.HasError() { + errs.Add(invalidOpts.Error()) + } + return middlewares, errs.Error() } func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 19bd92f..328cc64 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -4,64 +4,60 @@ import ( "fmt" "net/http" "os" + "path" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) -func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) { +func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware { fileContent, err := os.ReadFile(filePath) if err != nil { - return nil, E.FailWith("read middleware compose file", err) + eb.Add(err) + return nil } - return BuildMiddlewaresFromYAML(fileContent) + return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) } -func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) { - b := E.NewBuilder("middlewares compile errors") - defer b.To(&outErr) - +func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { var rawMap map[string][]map[string]any err := yaml.Unmarshal(data, &rawMap) if err != nil { - b.Add(E.FailWith("yaml unmarshal", err)) - return + eb.Add(err) + return nil } - middlewares = make(map[string]*Middleware) + middlewares := make(map[string]*Middleware) for name, defs := range rawMap { - chainErr := E.NewBuilder("%s", name) + chainErr := E.NewBuilder("") chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"] == "" { - chainErr.Add(E.Missing("use").Subjectf(".%d", i)) + chainErr.Addf("item %d: missing field 'use'", i) continue } baseName := def["use"].(string) - base, ok := Get(baseName) - if !ok { - base, ok = middlewares[baseName] - if !ok { - chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i)) - continue - } + base, err := Get(baseName) + if err != nil { + chainErr.Add(err.Subjectf("%s[%d]", name, i)) + continue } delete(def, "use") m, err := base.WithOptionsClone(def) if err != nil { - chainErr.Add(err.Subjectf("item%d", i)) + chainErr.Add(err.Subjectf("%s[%d]", name, i)) continue } m.name = fmt.Sprintf("%s[%d]", name, i) chain = append(chain, m) } if chainErr.HasError() { - b.Add(chainErr.Build()) + eb.Add(chainErr.Error().Subject(source)) } else { middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain) } } - return + return middlewares } // TODO: check conflict or duplicates. @@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { } if len(modResps) > 0 { m.modifyResponse = func(res *Response) error { - b := E.NewBuilder("errors in middleware") + errs := E.NewBuilder("modify response errors") for _, mr := range modResps { - b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name)) + if err := mr.modifyResponse(res); err != nil { + errs.Add(E.From(err).Subject(mr.name)) + } } - return b.Build().Error() + return errs.Error() } } diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/http/middleware/middleware_builder_test.go index d7fca0c..914655a 100644 --- a/internal/net/http/middleware/middleware_builder_test.go +++ b/internal/net/http/middleware/middleware_builder_test.go @@ -13,10 +13,10 @@ import ( var testMiddlewareCompose []byte func TestBuild(t *testing.T) { - middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) - ExpectNoError(t, err.Error()) - _, err = E.Check(json.MarshalIndent(middlewares, "", " ")) - ExpectNoError(t, err.Error()) + errs := E.NewBuilder("") + middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs) + ExpectNoError(t, errs.Error()) + E.Must(json.MarshalIndent(middlewares, "", " ")) // t.Log(string(data)) // TODO: test } diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 26e67a7..4ff3699 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -6,26 +6,37 @@ import ( "path" "strings" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) -var middlewares map[string]*Middleware +var allMiddlewares map[string]*Middleware -func Get(name string) (middleware *Middleware, ok bool) { - middleware, ok = middlewares[U.ToLowerNoSnake(name)] - return +var ( + ErrUnknownMiddleware = E.New("unknown middleware") + ErrDuplicatedMiddleware = E.New("duplicated middleware") +) + +func Get(name string) (*Middleware, Error) { + middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)] + if !ok { + return nil, ErrUnknownMiddleware. + Subject(name). + Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares))) + } + return middleware, nil } func All() map[string]*Middleware { - return middlewares + return allMiddlewares } // initialize middleware names and label parsers func init() { - middlewares = map[string]*Middleware{ + allMiddlewares = map[string]*Middleware{ "setxforwarded": SetXForwarded, "hidexforwarded": HideXForwarded, "redirecthttp": RedirectHTTP, @@ -39,10 +50,10 @@ func init() { // !experimental "forwardauth": ForwardAuth.m, - "oauth2": OAuth2.m, + // "oauth2": OAuth2.m, } names := make(map[*Middleware][]string) - for name, m := range middlewares { + for name, m := range allMiddlewares { names[m] = append(names[m], http.CanonicalHeaderKey(name)) } for m, names := range names { @@ -55,27 +66,30 @@ func init() { } func LoadComposeFiles() { - b := E.NewBuilder("failed to load middlewares") + errs := E.NewBuilder("middleware compile errors") middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0) if err != nil { - logrus.Errorf("failed to list middleware definitions: %s", err) + logger.Err(err).Msg("failed to list middleware definitions") return } for _, defFile := range middlewareDefs { - mws, err := BuildMiddlewaresFromComposeFile(defFile) + mws := BuildMiddlewaresFromComposeFile(defFile, errs) + if len(mws) == 0 { + continue + } for name, m := range mws { - if _, ok := middlewares[name]; ok { - b.Add(E.Duplicated("middleware", name)) + if _, ok := allMiddlewares[name]; ok { + errs.Add(ErrDuplicatedMiddleware.Subject(name)) continue } - middlewares[U.ToLowerNoSnake(name)] = m - logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) + allMiddlewares[U.ToLowerNoSnake(name)] = m + logger.Info(). + Str("name", name). + Str("src", path.Base(defFile)). + Msg("middleware loaded") } - b.Add(err.Subject(path.Base(defFile))) } - if b.HasError() { - logger.Error(b.Build()) + if errs.HasError() { + E.LogError(errs.About(), errs.Error(), &logger) } } - -var logger = logrus.WithField("module", "middlewares") diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 2b9ca2e..0b0ce60 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -12,9 +12,9 @@ type ( } // order: set_headers -> add_headers -> hide_headers modifyRequestOpts struct { - SetHeaders map[string]string - AddHeaders map[string]string - HideHeaders []string + SetHeaders map[string]string `json:"setHeaders"` + AddHeaders map[string]string `json:"addHeaders"` + HideHeaders []string `json:"hideHeaders"` } ) diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 1590e11..6d9d9ec 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyRequest.m.WithOptionsClone(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) @@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) { result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ middlewareOpt: opts, }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index 62011d8..4edd710 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -13,11 +13,7 @@ type ( m *Middleware } // order: set_headers -> add_headers -> hide_headers - modifyResponseOpts struct { - SetHeaders map[string]string - AddHeaders map[string]string - HideHeaders []string - } + modifyResponseOpts = modifyRequestOpts ) var ModifyResponse = &modifyResponse{ diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 65e98d6..370e590 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyResponse.m.WithOptionsClone(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) @@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) { result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{ middlewareOpt: opts, }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0") t.Log(result.ResponseHeaders.Get("Accept-Encoding")) ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go index b500bee..10ea6d6 100644 --- a/internal/net/http/middleware/oauth2.go +++ b/internal/net/http/middleware/oauth2.go @@ -1,129 +1,117 @@ package middleware -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "reflect" +// import ( +// "encoding/json" +// "fmt" +// "net/http" +// "net/url" - E "github.com/yusing/go-proxy/internal/error" -) +// E "github.com/yusing/go-proxy/internal/error" +// ) -type oAuth2 struct { - *oAuth2Opts - m *Middleware -} +// type oAuth2 struct { +// oAuth2Opts +// m *Middleware +// } -type oAuth2Opts struct { - ClientID string - ClientSecret string - AuthURL string // Authorization Endpoint - TokenURL string // Token Endpoint -} +// type oAuth2Opts struct { +// ClientID string `validate:"required"` +// ClientSecret string `validate:"required"` +// AuthURL string `validate:"required"` // Authorization Endpoint +// TokenURL string `validate:"required"` // Token Endpoint +// } -var OAuth2 = &oAuth2{ - m: &Middleware{withOptions: NewAuthentikOAuth2}, -} +// var OAuth2 = &oAuth2{ +// m: &Middleware{withOptions: NewAuthentikOAuth2}, +// } -func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { - oauth := new(oAuth2) - oauth.m = &Middleware{ - impl: oauth, - before: oauth.handleOAuth2, - } - oauth.oAuth2Opts = &oAuth2Opts{} - err := Deserialize(opts, oauth.oAuth2Opts) - if err != nil { - return nil, err - } - b := E.NewBuilder("missing required fields") - optV := reflect.ValueOf(oauth.oAuth2Opts) - for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) { - if optV.FieldByName(field.Name).Len() == 0 { - b.Add(E.Missing(field.Name)) - } - } - if b.HasError() { - return nil, b.Build().Subject("oAuth2") - } - return oauth.m, nil -} +// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { +// oauth := new(oAuth2) +// oauth.m = &Middleware{ +// impl: oauth, +// before: oauth.handleOAuth2, +// } +// err := Deserialize(opts, &oauth.oAuth2Opts) +// if err != nil { +// return nil, err +// } +// return oauth.m, nil +// } -func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { - // Check if the user is authenticated (you may use session, cookie, etc.) - if !userIsAuthenticated(r) { - // TODO: Redirect to OAuth2 auth URL - http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", - oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) - return - } +// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { +// // Check if the user is authenticated (you may use session, cookie, etc.) +// if !userIsAuthenticated(r) { +// // TODO: Redirect to OAuth2 auth URL +// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", +// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) +// return +// } - // If you have a token in the query string, process it - if code := r.URL.Query().Get("code"); code != "" { - // Exchange the authorization code for a token here - // Use the TokenURL and authenticate the user - token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI) - if err != nil { - // handle error - http.Error(rw, "failed to get token", http.StatusUnauthorized) - return - } +// // If you have a token in the query string, process it +// if code := r.URL.Query().Get("code"); code != "" { +// // Exchange the authorization code for a token here +// // Use the TokenURL and authenticate the user +// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI) +// if err != nil { +// // handle error +// http.Error(rw, "failed to get token", http.StatusUnauthorized) +// return +// } - // Save token and user info based on your requirements - saveToken(rw, token) +// // Save token and user info based on your requirements +// saveToken(rw, token) - // Redirect to the originally requested URL - http.Redirect(rw, r, "/", http.StatusFound) - return - } +// // Redirect to the originally requested URL +// http.Redirect(rw, r, "/", http.StatusFound) +// return +// } - // If user is authenticated, go to the next handler - next(rw, r) -} +// // If user is authenticated, go to the next handler +// next(rw, r) +// } -func userIsAuthenticated(r *http.Request) bool { - // Example: Check for a session or cookie - session, err := r.Cookie("session_token") - if err != nil || session.Value == "" { - return false - } - // Validate the session_token if necessary - return true -} +// func userIsAuthenticated(r *http.Request) bool { +// // Example: Check for a session or cookie +// session, err := r.Cookie("session_token") +// if err != nil || session.Value == "" { +// return false +// } +// // Validate the session_token if necessary +// return true +// } -func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { - // Prepare the request body - data := url.Values{ - "client_id": {opts.ClientID}, - "client_secret": {opts.ClientSecret}, - "code": {code}, - "grant_type": {"authorization_code"}, - "redirect_uri": {requestURI}, - } - resp, err := http.PostForm(opts.TokenURL, data) - if err != nil { - return "", fmt.Errorf("failed to request token: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) - } - // Decode the response - var tokenResp struct { - AccessToken string `json:"access_token"` - } - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token response: %w", err) - } - return tokenResp.AccessToken, nil -} +// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { +// // Prepare the request body +// data := url.Values{ +// "client_id": {opts.ClientID}, +// "client_secret": {opts.ClientSecret}, +// "code": {code}, +// "grant_type": {"authorization_code"}, +// "redirect_uri": {requestURI}, +// } +// resp, err := http.PostForm(opts.TokenURL, data) +// if err != nil { +// return "", fmt.Errorf("failed to request token: %w", err) +// } +// defer resp.Body.Close() +// if resp.StatusCode != http.StatusOK { +// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) +// } +// // Decode the response +// var tokenResp struct { +// AccessToken string `json:"access_token"` +// } +// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { +// return "", fmt.Errorf("failed to decode token response: %w", err) +// } +// return tokenResp.AccessToken, nil +// } -func saveToken(rw ResponseWriter, token string) { - // Example: Save token in cookie - http.SetCookie(rw, &http.Cookie{ - Name: "auth_token", - Value: token, - // set other properties as necessary, such as Secure and HttpOnly - }) -} +// func saveToken(rw ResponseWriter, token string) { +// // Example: Save token in cookie +// http.SetCookie(rw, &http.Cookie{ +// Name: "auth_token", +// Value: token, +// // set other properties as necessary, such as Secure and HttpOnly +// }) +// } diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 12d674c..0bbf452 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -16,9 +16,9 @@ type realIP struct { type realIPOpts struct { // Header is the name of the header to use for the real client IP - Header string + Header string `json:"header"` // From is a list of Address / CIDRs to trust - From []*types.CIDR + From []*types.CIDR `json:"from"` /* If recursive search is disabled, the original client address that matches one of the trusted addresses is replaced by @@ -27,7 +27,7 @@ type realIPOpts struct { the original client address that matches one of the trusted addresses is replaced by the last non-trusted address sent in the request header field. */ - Recursive bool + Recursive bool `json:"recursive"` } var RealIP = &realIP{ diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index f47c272..67457b8 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) { } ri, err := NewRealIP(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) for i, CIDR := range ri.impl.(*realIP).From { @@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) { "set_headers": map[string]string{testHeader: testRealIP}, } realip, err := NewRealIP(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) mr, err := NewModifyRequest(optsMr) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) result, err := newMiddlewareTest(mid, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) t.Log(traces) ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) diff --git a/internal/net/http/middleware/redirect_http_test.go b/internal/net/http/middleware/redirect_http_test.go index 6d2b2f6..b591fc2 100644 --- a/internal/net/http/middleware/redirect_http_test.go +++ b/internal/net/http/middleware/redirect_http_test.go @@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ scheme: "http", }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) } @@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ scheme: "https", }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusOK) } diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index 8c169e2..2b444f3 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -6,7 +6,7 @@ import ( "time" gphttp "github.com/yusing/go-proxy/internal/net/http" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type Trace struct { @@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace { return nil } return addTrace(&Trace{ - Time: U.FormatTime(time.Now()), + Time: strutils.FormatTime(time.Now()), Caller: m.Fullname(), Message: fmt.Sprintf(msg, args...), }) diff --git a/internal/net/http/modify_response_writer.go b/internal/net/http/modify_response_writer.go index 8ba0d72..a7a82c2 100644 --- a/internal/net/http/modify_response_writer.go +++ b/internal/net/http/modify_response_writer.go @@ -57,8 +57,7 @@ func (w *ModifyResponseWriter) WriteHeader(code int) { } if err := w.modifier(&resp); err != nil { - w.modifierErr = err - logger.Errorf("error modifying response: %s", err) + w.modifierErr = fmt.Errorf("response modifier error: %w", err) w.w.WriteHeader(http.StatusInternalServerError) return } diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index b3afc79..b4d4448 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -23,7 +23,7 @@ import ( "strings" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" "golang.org/x/net/http/httpguts" @@ -69,6 +69,8 @@ type ProxyRequest struct { // 1xx responses are forwarded to the client if the underlying // transport supports ClientTrace.Got1xxResponse. type ReverseProxy struct { + zerolog.Logger + // The transport used to perform proxy requests. // If nil, http.DefaultTransport is used. Transport http.RoundTripper @@ -149,7 +151,12 @@ func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) if transport == nil { panic("nil transport") } - rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target} + rp := &ReverseProxy{ + Logger: logger.With().Str("name", name).Logger(), + Transport: transport, + TargetName: name, + TargetURL: target, + } rp.ServeHTTP = rp.serveHTTP return rp } @@ -195,9 +202,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err switch { case errors.Is(err, context.Canceled), errors.Is(err, io.EOF): - logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) + logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error") default: - logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) + logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error") } if writeHeader { rw.WriteHeader(http.StatusBadGateway) @@ -219,6 +226,10 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response } func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { + if _, ok := rw.(DummyResponseWriter); ok { + return + } + transport := p.Transport ctx := req.Context() @@ -453,6 +464,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R resUpType := UpgradeType(res.Header) if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) + return } if !strings.EqualFold(reqUpType, resUpType) { p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true) @@ -518,5 +530,3 @@ func IsPrint(s string) bool { } return true } - -var logger = logrus.WithField("module", "http") diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 230ca16..65fc269 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -9,10 +9,15 @@ import ( type CIDR net.IPNet +var ( + ErrInvalidCIDR = E.New("invalid CIDR") + ErrInvalidCIDRType = E.New("invalid CIDR type") +) + func (cidr *CIDR) ConvertFrom(val any) E.Error { cidrStr, ok := val.(string) if !ok { - return E.TypeMismatch[string](val) + return ErrInvalidCIDRType.Subjectf("%T", val) } if !strings.Contains(cidrStr, "/") { @@ -20,7 +25,7 @@ func (cidr *CIDR) ConvertFrom(val any) E.Error { } _, ipnet, err := net.ParseCIDR(cidrStr) if err != nil { - return E.Invalid("CIDR", cidr) + return ErrInvalidCIDR.Subject(cidrStr) } *cidr = CIDR(*ipnet) return nil diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go index 871521f..2892a5f 100644 --- a/internal/net/types/stream.go +++ b/internal/net/types/stream.go @@ -5,9 +5,28 @@ import ( "net" ) -type Stream interface { - fmt.Stringer - net.Listener - Setup() error - Handle(conn net.Conn) error +type ( + Stream interface { + fmt.Stringer + StreamListener + Setup() error + Handle(conn StreamConn) error + } + StreamListener interface { + Addr() net.Addr + Accept() (StreamConn, error) + Close() error + } + StreamConn any + NetListenerWrapper struct { + net.Listener + } +) + +func NetListener(l net.Listener) StreamListener { + return NetListenerWrapper{Listener: l} +} + +func (l NetListenerWrapper) Accept() (StreamConn, error) { + return l.Listener.Accept() } diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 7d9d0e4..2b74f42 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -1,22 +1,32 @@ package notif import ( - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( Dispatcher struct { task task.Task - logCh chan *logrus.Entry + logCh chan *LogMessage providers F.Set[Provider] } + LogMessage struct { + Level zerolog.Level + Title, Message string + } ) var dispatcher *Dispatcher +var ErrUnknownNotifProvider = E.New("unknown notification provider") + +const dispatchErr = "notification dispatch error" + func init() { dispatcher = newNotifDispatcher() go dispatcher.start() @@ -25,7 +35,7 @@ func init() { func newNotifDispatcher() *Dispatcher { return &Dispatcher{ task: task.GlobalTask("notif dispatcher"), - logCh: make(chan *logrus.Entry), + logCh: make(chan *LogMessage), providers: F.NewSet[Provider](), } } @@ -34,11 +44,13 @@ func GetDispatcher() *Dispatcher { return dispatcher } -func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.Error) { +func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) { name := configSubTask.Name() createFunc, ok := Providers[name] if !ok { - return nil, E.NotExist("provider", name) + return nil, ErrUnknownNotifProvider. + Subject(name). + Withf(strutils.DoYouMean(utils.NearestField(name, Providers))) } if provider, err := createFunc(cfg); err != nil { return nil, err @@ -53,7 +65,6 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E. func (disp *Dispatcher) start() { defer dispatcher.task.Finish("dispatcher stopped") - defer close(dispatcher.logCh) for { select { @@ -65,36 +76,39 @@ func (disp *Dispatcher) start() { } } -func (disp *Dispatcher) dispatch(entry *logrus.Entry) { +func (disp *Dispatcher) dispatch(msg *LogMessage) { task := disp.task.Subtask("dispatch notif") - defer task.Finish("notifs dispatched") + defer task.Finish("notif dispatched") - errs := E.NewBuilder("errors sending notif") + errs := E.NewBuilder(dispatchErr) disp.providers.RangeAllParallel(func(p Provider) { - if err := p.Send(task.Context(), entry); err != nil { - errs.Addf("%s: %s", p.Name(), err) + if err := p.Send(task.Context(), msg); err != nil { + errs.Add(E.PrependSubject(p.Name(), err)) } }) - if err := errs.Build(); err != nil { - logrus.Error("notif dispatcher failure: ", err) + if errs.HasError() { + E.LogError(errs.About(), errs.Error()) } } -// Levels implements logrus.Hook. -func (disp *Dispatcher) Levels() []logrus.Level { - return []logrus.Level{ - logrus.WarnLevel, - logrus.ErrorLevel, - logrus.FatalLevel, - logrus.PanicLevel, - } -} +// Run implements zerolog.Hook. +// func (disp *Dispatcher) Run(e *zerolog.Event, level zerolog.Level, message string) { +// if strings.HasPrefix(message, dispatchErr) { // prevent recursion +// return +// } +// switch level { +// case zerolog.WarnLevel, zerolog.ErrorLevel, zerolog.FatalLevel, zerolog.PanicLevel: +// disp.logCh <- &LogMessage{ +// Level: level, +// Message: message, +// } +// } +// } -// Fire implements logrus.Hook. -func (disp *Dispatcher) Fire(entry *logrus.Entry) error { - if disp.providers.Size() == 0 { - return nil +func Notify(title, msg string) { + dispatcher.logCh <- &LogMessage{ + Level: zerolog.InfoLevel, + Title: title, + Message: msg, } - disp.logCh <- entry - return nil } diff --git a/internal/notif/gotify.go b/internal/notif/gotify.go index ff6be73..90c9e01 100644 --- a/internal/notif/gotify.go +++ b/internal/notif/gotify.go @@ -9,7 +9,7 @@ import ( "net/url" "github.com/gotify/server/v2/model" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" U "github.com/yusing/go-proxy/internal/utils" ) @@ -39,7 +39,7 @@ func newGotifyClient(cfg map[string]any) (Provider, E.Error) { url, uErr := url.Parse(client.URL) if uErr != nil { - return nil, E.FailWith("parse url", uErr) + return nil, E.Errorf("invalid gotify URL %s", client.URL) } client.url = url @@ -52,30 +52,23 @@ func (client *GotifyClient) Name() string { } // Send implements NotifProvider. -func (client *GotifyClient) Send(ctx context.Context, entry *logrus.Entry) error { +func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error { var priority int - var title string - switch entry.Level { - case logrus.WarnLevel: + switch logMsg.Level { + case zerolog.WarnLevel: priority = 2 - title = "Warning" - case logrus.ErrorLevel: + case zerolog.ErrorLevel: priority = 5 - title = "Error" - case logrus.FatalLevel, logrus.PanicLevel: + case zerolog.FatalLevel, zerolog.PanicLevel: priority = 8 - title = "Critical" default: return nil } - if subjects := FieldsAsTitle(entry); subjects != "" { - title = subjects + " " + title - } msg := &GotifyMessage{ - Title: title, - Message: entry.Message, + Title: logMsg.Title, + Message: logMsg.Message, Priority: priority, } diff --git a/internal/notif/logrus.go b/internal/notif/logrus.go deleted file mode 100644 index 6ca0bf1..0000000 --- a/internal/notif/logrus.go +++ /dev/null @@ -1,21 +0,0 @@ -package notif - -import ( - "fmt" - "strings" - - "github.com/sirupsen/logrus" - U "github.com/yusing/go-proxy/internal/utils" -) - -func FieldsAsTitle(entry *logrus.Entry) string { - if len(entry.Data) == 0 { - return "" - } - var parts []string - for k, v := range entry.Data { - parts = append(parts, fmt.Sprintf("%s: %s", k, v)) - } - parts[0] = U.Title(parts[0]) - return strings.Join(parts, ", ") -} diff --git a/internal/notif/providers.go b/internal/notif/providers.go index 33b3dc5..5ab6e9a 100644 --- a/internal/notif/providers.go +++ b/internal/notif/providers.go @@ -3,14 +3,13 @@ package notif import ( "context" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" ) type ( Provider interface { Name() string - Send(ctx context.Context, entry *logrus.Entry) error + Send(ctx context.Context, logMsg *LogMessage) error } ProviderCreateFunc func(map[string]any) (Provider, E.Error) ProviderConfig map[string]any diff --git a/internal/proxy/entry/entry.go b/internal/proxy/entry/entry.go index 0c17ae8..4eb71f3 100644 --- a/internal/proxy/entry/entry.go +++ b/internal/proxy/entry/entry.go @@ -23,18 +23,18 @@ func ValidateEntry(m *RawEntry) (Entry, E.Error) { scheme, err := T.NewScheme(m.Scheme) if err != nil { - return nil, err + return nil, E.From(err) } var entry Entry - e := E.NewBuilder("error validating entry") + errs := E.NewBuilder("entry validation failed") if scheme.IsStream() { - entry = validateStreamEntry(m, e) + entry = validateStreamEntry(m, errs) } else { - entry = validateRPEntry(m, scheme, e) + entry = validateRPEntry(m, scheme, errs) } - if err := e.Build(); err != nil { - return nil, err + if errs.HasError() { + return nil, errs.Error() } return entry, nil } diff --git a/internal/proxy/entry/raw.go b/internal/proxy/entry/raw.go index 6135618..1ebe78d 100644 --- a/internal/proxy/entry/raw.go +++ b/internal/proxy/entry/raw.go @@ -5,13 +5,14 @@ import ( "strings" "github.com/docker/docker/api/types" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -85,20 +86,20 @@ func (e *RawEntry) FillMissingFields() { } else if !isDocker { pp = "80" } else { - logrus.Debugf("no port found for %s", e.Alias) + logging.Debug().Msg("no port found for " + e.Alias) } } // replace private port with public port if using public IP. if e.Host == cont.PublicIP { if p, ok := cont.PrivatePortMapping[pp]; ok { - pp = U.PortString(p.PublicPort) + pp = strutils.PortString(p.PublicPort) } } // replace public port with private port if using private IP. if e.Host == cont.PrivateIP { if p, ok := cont.PublicPortMapping[pp]; ok { - pp = U.PortString(p.PrivatePort) + pp = strutils.PortString(p.PrivatePort) } } diff --git a/internal/proxy/entry/reverse_proxy.go b/internal/proxy/entry/reverse_proxy.go index 8d48652..976dde3 100644 --- a/internal/proxy/entry/reverse_proxy.go +++ b/internal/proxy/entry/reverse_proxy.go @@ -53,7 +53,7 @@ func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config { return rp.Idlewatcher } -func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry { +func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProxyEntry { cont := m.Container if cont == nil { cont = docker.DummyContainer @@ -64,35 +64,26 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEnt lb = nil } - host, err := fields.ValidateHost(m.Host) - b.Add(err) + host := E.Collect(errs, fields.ValidateHost, m.Host) + port := E.Collect(errs, fields.ValidatePort, m.Port) + pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns) + url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port)) + iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) - port, err := fields.ValidatePort(m.Port) - b.Add(err) - - pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns) - b.Add(err) - - url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) - b.Add(err) - - idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) - b.Add(err) - - if err != nil { + if errs.HasError() { return nil } return &ReverseProxyEntry{ Raw: m, - Alias: fields.NewAlias(m.Alias), + Alias: fields.Alias(m.Alias), Scheme: s, URL: net.NewURL(url), NoTLSVerify: m.NoTLSVerify, - PathPatterns: pathPatterns, + PathPatterns: pathPats, HealthCheck: m.HealthCheck, LoadBalance: lb, Middlewares: m.Middlewares, - Idlewatcher: idleWatcherCfg, + Idlewatcher: iwCfg, } } diff --git a/internal/proxy/entry/stream.go b/internal/proxy/entry/stream.go index 4416b20..5ffabe8 100644 --- a/internal/proxy/entry/stream.go +++ b/internal/proxy/entry/stream.go @@ -51,34 +51,25 @@ func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config { return s.Idlewatcher } -func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry { +func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry { cont := m.Container if cont == nil { cont = docker.DummyContainer } - host, err := fields.ValidateHost(m.Host) - b.Add(err) + host := E.Collect(errs, fields.ValidateHost, m.Host) + port := E.Collect(errs, fields.ValidateStreamPort, m.Port) + scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme) + url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort)) + idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) - port, err := fields.ValidateStreamPort(m.Port) - b.Add(err) - - scheme, err := fields.ValidateStreamScheme(m.Scheme) - b.Add(err) - - url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort))) - b.Add(err) - - idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) - b.Add(err) - - if b.HasError() { + if errs.HasError() { return nil } return &StreamEntry{ Raw: m, - Alias: fields.NewAlias(m.Alias), + Alias: fields.Alias(m.Alias), Scheme: *scheme, URL: url, Host: host, diff --git a/internal/proxy/fields/alias.go b/internal/proxy/fields/alias.go index 289f964..07a91eb 100644 --- a/internal/proxy/fields/alias.go +++ b/internal/proxy/fields/alias.go @@ -1,6 +1,3 @@ package fields -type ( - Alias string - NewAlias = Alias -) +type Alias string diff --git a/internal/proxy/fields/host.go b/internal/proxy/fields/host.go index 68c17c2..892e72c 100644 --- a/internal/proxy/fields/host.go +++ b/internal/proxy/fields/host.go @@ -1,14 +1,10 @@ package fields -import ( - E "github.com/yusing/go-proxy/internal/error" -) - type ( Host string Subdomain = Alias ) -func ValidateHost[String ~string](s String) (Host, E.Error) { +func ValidateHost[String ~string](s String) (Host, error) { return Host(s), nil } diff --git a/internal/proxy/fields/path_pattern.go b/internal/proxy/fields/path_pattern.go index 8d9abd3..29b7cf4 100644 --- a/internal/proxy/fields/path_pattern.go +++ b/internal/proxy/fields/path_pattern.go @@ -1,6 +1,8 @@ package fields import ( + "errors" + "fmt" "regexp" E "github.com/yusing/go-proxy/internal/error" @@ -13,12 +15,17 @@ type ( var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) -func ValidatePathPattern(s string) (PathPattern, E.Error) { +var ( + ErrEmptyPathPattern = errors.New("path must not be empty") + ErrInvalidPathPattern = errors.New("invalid path pattern") +) + +func ValidatePathPattern(s string) (PathPattern, error) { if len(s) == 0 { - return "", E.Invalid("path", "must not be empty") + return "", ErrEmptyPathPattern } if !pathPattern.MatchString(s) { - return "", E.Invalid("path pattern", s) + return "", fmt.Errorf("%w %q", ErrInvalidPathPattern, s) } return PathPattern(s), nil } @@ -27,13 +34,15 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.Error) { if len(s) == 0 { return []PathPattern{"/"}, nil } + errs := E.NewBuilder("invalid path patterns") pp := make(PathPatterns, len(s)) for i, v := range s { pattern, err := ValidatePathPattern(v) if err != nil { - return nil, err + errs.Add(err) + } else { + pp[i] = pattern } - pp[i] = pattern } - return pp, nil + return pp, errs.Error() } diff --git a/internal/proxy/fields/path_pattern_test.go b/internal/proxy/fields/path_pattern_test.go index d19cb97..261ee3f 100644 --- a/internal/proxy/fields/path_pattern_test.go +++ b/internal/proxy/fields/path_pattern_test.go @@ -1,9 +1,9 @@ package fields import ( + "errors" "testing" - E "github.com/yusing/go-proxy/internal/error" U "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -38,10 +38,10 @@ var invalidPatterns = []string{ func TestPathPatternRegex(t *testing.T) { for _, pattern := range validPatterns { _, err := ValidatePathPattern(pattern) - U.ExpectNoError(t, err.Error()) + U.ExpectNoError(t, err) } for _, pattern := range invalidPatterns { _, err := ValidatePathPattern(pattern) - U.ExpectError2(t, pattern, E.ErrInvalid, err.Error()) + U.ExpectTrue(t, errors.Is(err, ErrInvalidPathPattern)) } } diff --git a/internal/proxy/fields/port.go b/internal/proxy/fields/port.go index 5780005..9b809fd 100644 --- a/internal/proxy/fields/port.go +++ b/internal/proxy/fields/port.go @@ -4,22 +4,25 @@ import ( "strconv" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type Port int -func ValidatePort[String ~string](v String) (Port, E.Error) { - p, err := strconv.Atoi(string(v)) +var ErrPortOutOfRange = E.New("port out of range") + +func ValidatePort[String ~string](v String) (Port, error) { + p, err := strutils.Atoi(string(v)) if err != nil { - return ErrPort, E.Invalid("port number", v).With(err) + return ErrPort, err } return ValidatePortInt(p) } -func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) { +func ValidatePortInt[Int int | uint16](v Int) (Port, error) { p := Port(v) if !p.inBound() { - return ErrPort, E.OutOfRange("port", p) + return ErrPort, ErrPortOutOfRange.Subject(strconv.Itoa(int(p))) } return p, nil } diff --git a/internal/proxy/fields/scheme.go b/internal/proxy/fields/scheme.go index 2e4f6e5..b006464 100644 --- a/internal/proxy/fields/scheme.go +++ b/internal/proxy/fields/scheme.go @@ -6,12 +6,14 @@ import ( type Scheme string -func NewScheme[String ~string](s String) (Scheme, E.Error) { +var ErrInvalidScheme = E.New("invalid scheme") + +func NewScheme(s string) (Scheme, error) { switch s { case "http", "https", "tcp", "udp": return Scheme(s), nil } - return "", E.Invalid("scheme", s) + return "", ErrInvalidScheme.Subject(s) } func (s Scheme) IsHTTP() bool { return s == "http" } diff --git a/internal/proxy/fields/stream_port.go b/internal/proxy/fields/stream_port.go index 020d455..6cbc767 100644 --- a/internal/proxy/fields/stream_port.go +++ b/internal/proxy/fields/stream_port.go @@ -3,7 +3,6 @@ package fields import ( "strings" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" ) @@ -12,7 +11,9 @@ type StreamPort struct { ProxyPort Port `json:"proxy"` } -func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { +var ErrStreamPortTooManyColons = E.New("too many colons") + +func ValidateStreamPort(p string) (StreamPort, error) { split := strings.Split(p, ":") switch len(split) { @@ -21,36 +22,14 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { case 2: break default: - err = E.Invalid("stream port", p).With("too many colons") - return + return StreamPort{}, ErrStreamPortTooManyColons.Subject(p) } - listeningPort, err := ValidatePort(split[0]) - if err != nil { - err = err.Subject("listening port") - return - } - - proxyPort, err := ValidatePort(split[1]) - - if err.Is(E.ErrOutOfRange) { - err = err.Subject("proxy port") - return - } else if err != nil { - proxyPort, err = parseNameToPort(split[1]) - if err != nil { - err = E.Invalid("proxy port", proxyPort) - return - } + listeningPort, lErr := ValidatePort(split[0]) + proxyPort, pErr := ValidatePort(split[1]) + if err := E.Join(lErr, pErr); err != nil { + return StreamPort{}, err } return StreamPort{listeningPort, proxyPort}, nil } - -func parseNameToPort(name string) (Port, E.Error) { - port, ok := common.ServiceNamePortMapTCP[name] - if !ok { - return ErrPort, E.Invalid("service", name) - } - return Port(port), nil -} diff --git a/internal/proxy/fields/stream_port_test.go b/internal/proxy/fields/stream_port_test.go index fce707c..a18732a 100644 --- a/internal/proxy/fields/stream_port_test.go +++ b/internal/proxy/fields/stream_port_test.go @@ -1,9 +1,9 @@ package fields import ( + "strconv" "testing" - E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -11,7 +11,6 @@ var validPorts = []string{ "1234:5678", "0:2345", "2345", - "1234:postgres", } var invalidPorts = []string{ @@ -19,7 +18,6 @@ var invalidPorts = []string{ "123:", "0:", ":1234", - "1234:1234:1234", "qwerty", "asdfgh:asdfgh", "1234:asdfgh", @@ -32,17 +30,25 @@ var outOfRangePorts = []string{ "0:65536", } +var tooManyColonsPorts = []string{ + "1234:1234:1234", +} + func TestStreamPort(t *testing.T) { for _, port := range validPorts { _, err := ValidateStreamPort(port) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) } for _, port := range invalidPorts { _, err := ValidateStreamPort(port) - ExpectError2(t, port, E.ErrInvalid, err.Error()) + ExpectError2(t, port, strconv.ErrSyntax, err) } for _, port := range outOfRangePorts { _, err := ValidateStreamPort(port) - ExpectError2(t, port, E.ErrOutOfRange, err.Error()) + ExpectError2(t, port, ErrPortOutOfRange, err) + } + for _, port := range tooManyColonsPorts { + _, err := ValidateStreamPort(port) + ExpectError2(t, port, ErrStreamPortTooManyColons, err) } } diff --git a/internal/proxy/fields/stream_scheme.go b/internal/proxy/fields/stream_scheme.go index d195a29..0c0180e 100644 --- a/internal/proxy/fields/stream_scheme.go +++ b/internal/proxy/fields/stream_scheme.go @@ -12,22 +12,23 @@ type StreamScheme struct { ProxyScheme Scheme `json:"proxy"` } -func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) { - ss = &StreamScheme{} +func ValidateStreamScheme(s string) (*StreamScheme, error) { + ss := &StreamScheme{} parts := strings.Split(s, ":") if len(parts) == 1 { parts = []string{s, s} } else if len(parts) != 2 { - return nil, E.Invalid("stream scheme", s) + return nil, ErrInvalidScheme.Subject(s) } - ss.ListeningScheme, err = NewScheme(parts[0]) - if err.HasError() { - return nil, err - } - ss.ProxyScheme, err = NewScheme(parts[1]) - if err.HasError() { + + var lErr, pErr error + ss.ListeningScheme, lErr = NewScheme(parts[0]) + ss.ProxyScheme, pErr = NewScheme(parts[1]) + + if err := E.Join(lErr, pErr); err != nil { return nil, err } + return ss, nil } diff --git a/internal/proxy/fields/stream_scheme_test.go b/internal/proxy/fields/stream_scheme_test.go new file mode 100644 index 0000000..15227c0 --- /dev/null +++ b/internal/proxy/fields/stream_scheme_test.go @@ -0,0 +1,37 @@ +package fields + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +var ( + validStreamSchemes = []string{ + "tcp:tcp", + "tcp:udp", + "udp:tcp", + "udp:udp", + "tcp", + "udp", + } + + invalidStreamSchemes = []string{ + "tcp:tcp:", + "tcp:", + ":udp:", + ":udp", + "top", + } +) + +func TestNewStreamScheme(t *testing.T) { + for _, s := range validStreamSchemes { + _, err := ValidateStreamScheme(s) + ExpectNoError(t, err) + } + for _, s := range invalidStreamSchemes { + _, err := ValidateStreamScheme(s) + ExpectError(t, ErrInvalidScheme, err) + } +} diff --git a/internal/route/http.go b/internal/route/http.go index 0010e3c..f2b40a7 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -7,13 +7,13 @@ import ( "strings" "sync" - "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/proxy/entry" PT "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/task" @@ -33,6 +33,8 @@ type ( rp *gphttp.ReverseProxy task task.Task + + l zerolog.Logger } SubdomainKey = PT.Alias @@ -88,6 +90,10 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) { ReverseProxyEntry: entry, rp: rp, task: task.DummyTask(), + l: logger.With(). + Str("type", string(entry.Scheme)). + Str("name", string(entry.Alias)). + Logger(), } return r, nil } @@ -107,11 +113,11 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { defer httpRoutesMu.Unlock() if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { - logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) + r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled") if r.HealthCheck == nil { r.HealthCheck = new(health.HealthCheckConfig) } - r.HealthCheck.Disable = true + r.HealthCheck.Disable = false } switch { @@ -143,7 +149,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { if r.HealthMon != nil { if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { - logrus.Warn(E.FailWith("health monitor", err)) + E.LogWarn("health monitor error", err, &r.l) } } @@ -209,15 +215,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if err != nil { if !middleware.ServeStaticErrorPageFile(w, r) { - logrus.Error(E.Failure("request"). - Subjectf("%s %s", r.Method, r.URL.String()). - With(err)) + logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "text/html; charset=utf-8") if _, err := w.Write(errorPage); err != nil { - logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err) + logger.Err(err).Msg("failed to write error page") } } else { http.Error(w, err.Error(), http.StatusNotFound) diff --git a/internal/route/logger.go b/internal/route/logger.go new file mode 100644 index 0000000..caf8f51 --- /dev/null +++ b/internal/route/logger.go @@ -0,0 +1,5 @@ +package route + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "route").Logger() diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 440e47b..c1dca89 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -6,77 +6,90 @@ import ( "strings" "github.com/docker/docker/client" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" - D "github.com/yusing/go-proxy/internal/docker" + "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/proxy/entry" - R "github.com/yusing/go-proxy/internal/route" - W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/utils/strutils" + "github.com/yusing/go-proxy/internal/watcher" ) type DockerProvider struct { name, dockerHost string ExplicitOnly bool + l zerolog.Logger } var ( AliasRefRegex = regexp.MustCompile(`#\d+`) AliasRefRegexOld = regexp.MustCompile(`\$\d+`) + + ErrAliasRefIndexOutOfRange = E.New("index out of range") ) -func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) { +func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, error) { if dockerHost == common.DockerHostFromEnv { dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost) } - return &DockerProvider{name, dockerHost, explicitOnly}, nil + return &DockerProvider{ + name, + dockerHost, + explicitOnly, + logger.With().Str("type", "docker").Str("name", name).Logger(), + }, nil } func (p *DockerProvider) String() string { return "docker@" + p.name } -func (p *DockerProvider) NewWatcher() W.Watcher { - return W.NewDockerWatcher(p.dockerHost) +func (p *DockerProvider) Logger() *zerolog.Logger { + return &p.l } -func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) { - routes = R.NewRoutes() +func (p *DockerProvider) NewWatcher() watcher.Watcher { + return watcher.NewDockerWatcher(p.dockerHost) +} + +func (p *DockerProvider) loadRoutesImpl() (route.Routes, E.Error) { + routes := route.NewRoutes() entries := entry.NewProxyEntries() - containers, err := D.ListContainers(p.dockerHost) + containers, err := docker.ListContainers(p.dockerHost) if err != nil { - return routes, err + return routes, E.From(err) } - errors := E.NewBuilder("errors in docker labels") + errs := E.NewBuilder("") for _, c := range containers { - container := D.FromDocker(&c, p.dockerHost) + container := docker.FromDocker(&c, p.dockerHost) if container.IsExcluded { continue } newEntries, err := p.entriesFromContainerLabels(container) if err != nil { - errors.Add(err) + errs.Add(err.Subject(container.ContainerName)) } // although err is not nil // there may be some valid entries in `en` dups := entries.MergeFrom(newEntries) // add the duplicate proxy entries to the error dups.RangeAll(func(k string, v *entry.RawEntry) { - errors.Addf("duplicate alias %s", k) + errs.Addf("duplicated alias %s", k) }) } - routes, err = R.FromEntries(entries) - errors.Add(err) + routes, err = route.FromEntries(entries) + errs.Add(err) - return routes, errors.Build() + return routes, errs.Error() } -func (p *DockerProvider) shouldIgnore(container *D.Container) bool { +func (p *DockerProvider) shouldIgnore(container *docker.Container) bool { return container.IsExcluded || !container.IsExplicit && p.ExplicitOnly || !container.IsExplicit && container.IsDatabase || @@ -85,7 +98,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool { // Returns a list of proxy entries for a container. // Always non-nil. -func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) { +func (p *DockerProvider) entriesFromContainerLabels(container *docker.Container) (entries entry.RawEntries, _ E.Error) { entries = entry.NewProxyEntries() if p.shouldIgnore(container) { @@ -100,9 +113,9 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent }) } - errors := E.NewBuilder("failed to apply label") + errs := E.NewBuilder("label errors") for key, val := range container.Labels { - errors.Add(p.applyLabel(container, entries, key, val)) + errs.Add(p.applyLabel(container, entries, key, val)) } // remove all entries that failed to fill in missing fields @@ -110,59 +123,56 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent re.FillMissingFields() }) - return entries, errors.Build().Subject(container.ContainerName) + return entries, errs.Error() } -func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) { - b := E.NewBuilder("errors in label %s", key) - defer b.To(&res) +func (p *DockerProvider) applyLabel(container *docker.Container, entries entry.RawEntries, key, val string) E.Error { + lbl := docker.ParseLabel(key, val) + if lbl.Namespace != docker.NSProxy { + return nil + } + if lbl.Target == docker.WildcardAlias { + // apply label for all aliases + labelErrs := entries.CollectErrors(func(a string, e *entry.RawEntry) error { + return docker.ApplyLabel(e, lbl) + }) + if err := E.Join(labelErrs...); err != nil { + return err.Subject(lbl.Target) + } + return nil + } - refErr := E.NewBuilder("errors in alias references") + refErrs := E.NewBuilder("alias ref errors") replaceIndexRef := func(ref string) string { - index, err := strconv.Atoi(ref[1:]) + index, err := strutils.Atoi(ref[1:]) if err != nil { - refErr.Add(E.Invalid("integer", ref)) + refErrs.Add(err) return ref } if index < 1 || index > len(container.Aliases) { - refErr.Add(E.OutOfRange("index", ref)) + refErrs.Add(ErrAliasRefIndexOutOfRange.Subject(strconv.Itoa(index))) return ref } return container.Aliases[index-1] } - lbl, err := D.ParseLabel(key, val) - if err != nil { - b.Add(err.Subject(key)) + lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef) + lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(ref string) string { + p.l.Warn().Msgf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#")) + return replaceIndexRef(ref) + }) + if refErrs.HasError() { + return refErrs.Error().Subject(lbl.String()) } - if lbl.Namespace != D.NSProxy { - return + + en, ok := entries.Load(lbl.Target) + if !ok { + en = &entry.RawEntry{ + Alias: lbl.Target, + Container: container, + } + entries.Store(lbl.Target, en) } - if lbl.Target == D.WildcardAlias { - // apply label for all aliases - entries.RangeAll(func(a string, e *entry.RawEntry) { - if err = D.ApplyLabel(e, lbl); err != nil { - b.Add(err) - } - }) - } else { - lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef) - lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string { - logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#")) - return replaceIndexRef(s) - }) - if refErr.HasError() { - b.Add(refErr.Build()) - return - } - config, ok := entries.Load(lbl.Target) - if !ok { - b.Add(E.NotExist("alias", lbl.Target)) - return - } - if err = D.ApplyLabel(config, lbl); err != nil { - b.Add(err) - } - } - return + + return docker.ApplyLabel(en, lbl) } diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index 43d3cc0..dfd0558 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -1,8 +1,8 @@ package provider import ( - "strings" "testing" + "time" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/network" @@ -20,7 +20,7 @@ var ( p DockerProvider ) -func TestApplyLabelWildcard(t *testing.T) { +func TestApplyLabel(t *testing.T) { pathPatterns := ` - / - POST /upload/{$} @@ -61,9 +61,13 @@ func TestApplyLabelWildcard(t *testing.T) { "proxy.a.middlewares.middleware1.prop2": "value2", "proxy.a.middlewares.middleware2.prop3": "value3", "proxy.a.middlewares.middleware2.prop4": "value4", + "proxy.a.homepage.show": "true", + "proxy.a.homepage.icon": "png/example.png", + "proxy.a.healthcheck.path": "/ping", + "proxy.a.healthcheck.interval": "10s", }, }, client.DefaultDockerHost)) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) a, ok := entries.Load("a") ExpectTrue(t, ok) @@ -102,6 +106,12 @@ func TestApplyLabelWildcard(t *testing.T) { ExpectEqual(t, a.Container.StopSignal, "SIGTERM") ExpectEqual(t, b.Container.StopSignal, "SIGTERM") + + ExpectEqual(t, a.Homepage.Show, true) + ExpectEqual(t, a.Homepage.Icon, "png/example.png") + + ExpectEqual(t, a.HealthCheck.Path, "/ping") + ExpectEqual(t, a.HealthCheck.Interval, 10*time.Second) } func TestApplyLabelWithAlias(t *testing.T) { @@ -123,7 +133,7 @@ func TestApplyLabelWithAlias(t *testing.T) { c, ok := entries.Load("c") ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, a.Scheme, "http") ExpectEqual(t, a.Port, "3333") ExpectEqual(t, a.NoTLSVerify, true) @@ -133,7 +143,7 @@ func TestApplyLabelWithAlias(t *testing.T) { } func TestApplyLabelWithRef(t *testing.T) { - entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, State: "running", Labels: map[string]string{ @@ -171,8 +181,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { }, }, "") _, err := p.entriesFromContainerLabels(c) - ExpectError(t, E.ErrOutOfRange, err.Error()) - ExpectTrue(t, strings.Contains(err.String(), "index out of range")) + ExpectError(t, ErrAliasRefIndexOutOfRange, err) _, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, @@ -182,13 +191,33 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { "proxy.#0.host": "localhost", }, }, "")) - ExpectError(t, E.ErrOutOfRange, err.Error()) - ExpectTrue(t, strings.Contains(err.String(), "index out of range")) + ExpectError(t, ErrAliasRefIndexOutOfRange, err) +} + +func TestDynamicAliases(t *testing.T) { + c := D.FromDocker(&types.Container{ + Names: []string{"app1"}, + State: "running", + Labels: map[string]string{ + "proxy.app1.port": "1234", + "proxy.app1_backend.port": "5678", + }, + }, client.DefaultDockerHost) + + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("app1") + ExpectTrue(t, ok) + ExpectEqual(t, raw.Scheme, "http") + ExpectEqual(t, raw.Port, "1234") + + raw, ok = E.Must(p.entriesFromContainerLabels(c)).Load("app1_backend") + ExpectTrue(t, ok) + ExpectEqual(t, raw.Scheme, "http") + ExpectEqual(t, raw.Port, "5678") } func TestPublicIPLocalhost(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1") ExpectEqual(t, raw.Host, raw.Container.PublicIP) @@ -196,7 +225,7 @@ func TestPublicIPLocalhost(t *testing.T) { func TestPublicIPRemote(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") ExpectEqual(t, raw.Host, raw.Container.PublicIP) @@ -213,7 +242,7 @@ func TestPrivateIPLocalhost(t *testing.T) { }, }, }, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PrivateIP, "172.17.0.123") ExpectEqual(t, raw.Host, raw.Container.PrivateIP) @@ -231,7 +260,7 @@ func TestPrivateIPRemote(t *testing.T) { }, }, }, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PrivateIP, "") ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") @@ -260,9 +289,9 @@ func TestStreamDefaultValues(t *testing.T) { t.Run("local", func(t *testing.T) { c := D.FromDocker(cont, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - en := Must(entry.ValidateEntry(raw)) + en := E.Must(entry.ValidateEntry(raw)) a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) @@ -273,9 +302,9 @@ func TestStreamDefaultValues(t *testing.T) { t.Run("remote", func(t *testing.T) { c := D.FromDocker(cont, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - en := Must(entry.ValidateEntry(raw)) + en := E.Must(entry.ValidateEntry(raw)) a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) @@ -286,7 +315,7 @@ func TestStreamDefaultValues(t *testing.T) { } func TestExplicitExclude(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ D.LabelAliases: "a", @@ -299,7 +328,7 @@ func TestExplicitExclude(t *testing.T) { func TestImplicitExcludeDatabase(t *testing.T) { t.Run("mount path detection", func(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Mounts: []types.MountPoint{ {Source: "/data", Destination: "/var/lib/postgresql/data"}, @@ -308,7 +337,7 @@ func TestImplicitExcludeDatabase(t *testing.T) { ExpectFalse(t, ok) }) t.Run("exposed port detection", func(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Ports: []types.Port{ {Type: "tcp", PrivatePort: 5432, PublicPort: 5432}, @@ -317,58 +346,3 @@ func TestImplicitExcludeDatabase(t *testing.T) { ExpectFalse(t, ok) }) } - -// func TestImplicitExcludeNoExposedPort(t *testing.T) { -// var p DockerProvider -// entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// State: "running", -// }, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectFalse(t, ok) -// } - -// func TestNotExcludeSpecifiedPort(t *testing.T) { -// var p DockerProvider -// entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// Labels: map[string]string{ -// "proxy.redis.port": "6379:6379", // but specified in label -// }, -// }, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectTrue(t, ok) -// } - -// func TestNotExcludeNonExposedPortHostNetwork(t *testing.T) { -// var p DockerProvider -// cont := &types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// Labels: map[string]string{ -// "proxy.redis.port": "6379:6379", -// }, -// } -// cont.HostConfig.NetworkMode = "host" - -// entries, err := p.entriesFromContainerLabels(D.FromDocker(cont, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectTrue(t, ok) -// } diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 0118cd2..6c990ce 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -12,25 +12,27 @@ import ( type EventHandler struct { provider *Provider - added []string - removed []string - paused []string - updated []string - errs E.Builder + errs *E.Builder + added *E.Builder + removed *E.Builder + updated *E.Builder } func (provider *Provider) newEventHandler() *EventHandler { return &EventHandler{ provider: provider, errs: E.NewBuilder("event errors"), + added: E.NewBuilder("added"), + removed: E.NewBuilder("removed"), + updated: E.NewBuilder("updated"), } } func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { oldRoutes := handler.provider.routes - newRoutes, err := handler.provider.LoadRoutesImpl() + newRoutes, err := handler.provider.loadRoutesImpl() if err != nil { - handler.errs.Add(err.Subject("load routes")) + handler.errs.Add(err) if newRoutes.Size() == 0 { return } @@ -41,17 +43,19 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { for _, event := range events { eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) } - handler.provider.l.Debug(eventsLog.String()) + E.LogDebug(eventsLog.About(), eventsLog.Error(), handler.provider.Logger()) + oldRoutesLog := E.NewBuilder("old routes") - oldRoutes.RangeAll(func(k string, r *route.Route) { - oldRoutesLog.Addf(k) + oldRoutes.RangeAllParallel(func(k string, r *route.Route) { + oldRoutesLog.Adds(k) }) - handler.provider.l.Debug(oldRoutesLog.String()) + E.LogDebug(oldRoutesLog.About(), oldRoutesLog.Error(), handler.provider.Logger()) + newRoutesLog := E.NewBuilder("new routes") - newRoutes.RangeAll(func(k string, r *route.Route) { - newRoutesLog.Addf(k) + newRoutes.RangeAllParallel(func(k string, r *route.Route) { + newRoutesLog.Adds(k) }) - handler.provider.l.Debug(newRoutesLog.String()) + E.LogDebug(newRoutesLog.About(), newRoutesLog.Error(), handler.provider.Logger()) } oldRoutes.RangeAll(func(k string, oldr *route.Route) { @@ -95,41 +99,35 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool func (handler *EventHandler) Add(parent task.Task, route *route.Route) { err := handler.provider.startRoute(parent, route) if err != nil { - handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err)) + handler.errs.Add(err.Subject("add")) } else { - handler.added = append(handler.added, route.Entry.Alias) + handler.added.Adds(route.Entry.Alias) } } func (handler *EventHandler) Remove(route *route.Route) { route.Finish("route removed") handler.provider.routes.Delete(route.Entry.Alias) - handler.removed = append(handler.removed, route.Entry.Alias) + handler.removed.Adds(route.Entry.Alias) } func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) { oldRoute.Finish("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { - handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err)) + handler.errs.Add(err.Subject("update")) } else { - handler.updated = append(handler.updated, newRoute.Entry.Alias) + handler.updated.Adds(newRoute.Entry.Alias) } } func (handler *EventHandler) Log() { results := E.NewBuilder("event occured") - for _, alias := range handler.added { - results.Addf("added %s", alias) - } - for _, alias := range handler.removed { - results.Addf("removed %s", alias) - } - for _, alias := range handler.updated { - results.Addf("updated %s", alias) - } - results.Add(handler.errs.Build()) - if result := results.Build(); result != nil { - handler.provider.l.Info(result) + results.Add(handler.added.Error()) + results.Add(handler.removed.Error()) + results.Add(handler.updated.Error()) + results.Add(handler.errs.Error()) + if result := results.String(); result != "" { + handler.provider.Logger().Info().Msg(result) } } diff --git a/internal/route/provider/file.go b/internal/route/provider/file.go index 8f302e0..754b680 100644 --- a/internal/route/provider/file.go +++ b/internal/route/provider/file.go @@ -1,11 +1,10 @@ package provider import ( - "errors" "os" "path" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/proxy/entry" @@ -17,53 +16,49 @@ import ( type FileProvider struct { fileName string path string + l zerolog.Logger } -func FileProviderImpl(filename string) (ProviderImpl, E.Error) { +func FileProviderImpl(filename string) (ProviderImpl, error) { impl := &FileProvider{ fileName: filename, path: path.Join(common.ConfigBasePath, filename), + l: logger.With().Str("type", "file").Str("name", filename).Logger(), } _, err := os.Stat(impl.path) - switch { - case err == nil: - return impl, nil - case errors.Is(err, os.ErrNotExist): - return nil, E.NotExist("file", impl.path) - default: - return nil, E.UnexpectedError(err) + if err != nil { + return nil, err } + return impl, nil } func Validate(data []byte) E.Error { return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data) } -func (p FileProvider) String() string { +func (p *FileProvider) String() string { return p.fileName } -func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) { - routes = R.NewRoutes() - - b := E.NewBuilder("validation failure") - defer b.To(&res) +func (p *FileProvider) Logger() *zerolog.Logger { + return &p.l +} +func (p *FileProvider) loadRoutesImpl() (R.Routes, E.Error) { + routes := R.NewRoutes() entries := entry.NewProxyEntries() - data, err := E.Check(os.ReadFile(p.path)) + data, err := os.ReadFile(p.path) if err != nil { - b.Add(E.FailWith("read file", err)) - return + return routes, E.From(err) } - if err = entries.UnmarshalFromYAML(data); err != nil { - b.Add(err) - return + if err := entries.UnmarshalFromYAML(data); err != nil { + return routes, E.From(err) } if err := Validate(data); err != nil { - logrus.Warn(err) + E.LogWarn(p.fileName+": validation failure", err) } return R.FromEntries(entries) diff --git a/internal/route/provider/logger.go b/internal/route/provider/logger.go new file mode 100644 index 0000000..d734627 --- /dev/null +++ b/internal/route/provider/logger.go @@ -0,0 +1,5 @@ +package provider + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "provider").Logger() diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 195d910..e6cb01e 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -1,11 +1,12 @@ package provider import ( + "errors" "fmt" "path" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/task" @@ -22,13 +23,12 @@ type ( routes R.Routes watcher W.Watcher - - l *logrus.Entry } ProviderImpl interface { fmt.Stringer + loadRoutesImpl() (R.Routes, E.Error) NewWatcher() W.Watcher - LoadRoutesImpl() (R.Routes, E.Error) + Logger() *zerolog.Logger } ProviderType string ProviderStats struct { @@ -45,20 +45,22 @@ const ( providerEventFlushInterval = 300 * time.Millisecond ) +var ( + ErrEmptyProviderName = errors.New("empty provider name") +) + func newProvider(name string, t ProviderType) *Provider { - p := &Provider{ + return &Provider{ name: name, t: t, routes: R.NewRoutes(), } - p.l = logrus.WithField("provider", p) - return p } -func NewFileProvider(filename string) (p *Provider, err E.Error) { +func NewFileProvider(filename string) (p *Provider, err error) { name := path.Base(filename) if name == "" { - return nil, E.Invalid("file name", "empty") + return nil, ErrEmptyProviderName } p = newProvider(name, ProviderTypeFile) p.ProviderImpl, err = FileProviderImpl(filename) @@ -69,9 +71,9 @@ func NewFileProvider(filename string) (p *Provider, err E.Error) { return } -func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) { +func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) { if name == "" { - return nil, E.Invalid("provider name", "empty") + return nil, ErrEmptyProviderName } p = newProvider(name, ProviderTypeDocker) @@ -106,7 +108,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { if err != nil { p.routes.Delete(r.Entry.Alias) subtask.Finish(err) // just to ensure - return err + return err.Subject(r.Entry.Alias) } else { p.routes.Store(r.Entry.Alias, r) subtask.OnFinished("del from provider", func() { @@ -117,16 +119,14 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { } // Start implements task.TaskStarter. -func (p *Provider) Start(configSubtask task.Task) (res E.Error) { - errors := E.NewBuilder("errors starting routes") - defer errors.To(&res) - +func (p *Provider) Start(configSubtask task.Task) E.Error { // routes and event queue will stop on parent cancel providerTask := configSubtask - p.routes.RangeAllParallel(func(alias string, r *R.Route) { - errors.Add(p.startRoute(providerTask, r)) - }) + errs := p.routes.CollectErrorsParallel( + func(alias string, r *R.Route) error { + return p.startRoute(providerTask, r) + }) eventQueue := events.NewEventQueue( providerTask, @@ -139,11 +139,15 @@ func (p *Provider) Start(configSubtask task.Task) (res E.Error) { flushTask.Finish("events flushed") }, func(err E.Error) { - p.l.Error(err) + E.LogError("event error", err, p.Logger()) }, ) eventQueue.Start(p.watcher.Events(providerTask.Context())) - return + + if err := E.Join(errs...); err != nil { + return err.Subject(p.String()) + } + return nil } func (p *Provider) RangeRoutes(do func(string, *R.Route)) { @@ -156,14 +160,14 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) { func (p *Provider) LoadRoutes() E.Error { var err E.Error - p.routes, err = p.LoadRoutesImpl() + p.routes, err = p.loadRoutesImpl() if p.routes.Size() > 0 { return err } if err == nil { return nil } - return E.FailWith("loading routes", err) + return err } func (p *Provider) NumRoutes() int { diff --git a/internal/route/raw.go b/internal/route/raw.go deleted file mode 100644 index f206a74..0000000 --- a/internal/route/raw.go +++ /dev/null @@ -1,94 +0,0 @@ -package route - -import ( - "errors" - "fmt" - "net" - "time" - - T "github.com/yusing/go-proxy/internal/proxy/fields" - U "github.com/yusing/go-proxy/internal/utils" -) - -type ( - RawStream struct { - *StreamRoute - - listener net.Listener - targetAddr net.Addr - } -) - -const ( - streamBufferSize = 8192 - streamDialTimeout = 5 * time.Second -) - -func NewRawStreamRoute(base *StreamRoute) *RawStream { - return &RawStream{ - StreamRoute: base, - } -} - -func (route *RawStream) Setup() error { - var lcfg net.ListenConfig - var err error - - switch route.Scheme.ListeningScheme { - case "tcp": - route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) - if err != nil { - return err - } - tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) - route.listener = tcpListener - case "udp": - route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) - if err != nil { - return err - } - udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port) - route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener) - default: - return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme)) - } - - return nil -} - -func (route *RawStream) Accept() (net.Conn, error) { - if route.listener == nil { - return nil, errors.New("listener not yet set up") - } - return route.listener.Accept() -} - -func (route *RawStream) Handle(c net.Conn) error { - clientConn := c.(net.Conn) - - defer clientConn.Close() - route.task.OnCancel("close conn", func() { clientConn.Close() }) - - dialer := &net.Dialer{Timeout: streamDialTimeout} - - serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) - serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr) - if err != nil { - return err - } - - pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) - return pipe.Start() -} - -func (route *RawStream) Close() error { - return route.listener.Close() -} diff --git a/internal/route/route.go b/internal/route/route.go index 93b3a4c..8cace61 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -89,5 +89,5 @@ func FromEntries(entries entry.RawEntries) (Routes, E.Error) { } }) - return routes, b.Build() + return routes, b.Error() } diff --git a/internal/route/stream.go b/internal/route/stream.go index 957a166..c94d75b 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" net "github.com/yusing/go-proxy/internal/net/types" @@ -18,13 +18,14 @@ import ( type StreamRoute struct { *entry.StreamEntry - net.Stream `json:"-"` + + stream net.Stream `json:"-"` HealthMon health.HealthMonitor `json:"health"` task task.Task - l logrus.FieldLogger + l zerolog.Logger } var ( @@ -39,11 +40,15 @@ func GetStreamProxies() F.Map[string, *StreamRoute] { func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) { // TODO: support non-coherent scheme if !entry.Scheme.IsCoherent() { - return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) + return nil, E.Errorf("unsupported scheme: %v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme) } return &StreamRoute{ StreamEntry: entry, task: task.DummyTask(), + l: logger.With(). + Str("type", string(entry.Scheme.ListeningScheme)). + Str("name", entry.TargetName()). + Logger(), }, nil } @@ -62,57 +67,54 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.Error { defer streamRoutesMu.Unlock() if r.HealthCheck.Disable && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { - logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) - r.HealthCheck.Disable = true + r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled") + r.HealthCheck.Disable = false } - // if r.Scheme.ListeningScheme.IsTCP() { - // r.Stream = NewTCPRoute(r) - // } else { - // r.Stream = NewUDPRoute(r) - // } r.task = providerSubtask - r.Stream = NewRawStreamRoute(r) - r.l = logrus.WithField("route", r.Stream.String()) + r.stream = NewStream(r) switch { case entry.UseIdleWatcher(r): wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) - waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) + waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.stream) if err != nil { r.task.Finish(err) return err } - r.Stream = waker + r.stream = waker r.HealthMon = waker case entry.UseHealthCheck(r): r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck) } - if err := r.Setup(); err != nil { + if err := r.stream.Setup(); err != nil { r.task.Finish(err) - return E.FailWith("setup", err) + return E.From(err) } r.task.OnFinished("close stream", func() { - if err := r.Close(); err != nil { - r.l.Error("close stream error: ", err) + if err := r.stream.Close(); err != nil { + E.LogError("close stream failed", err, &r.l) } }) - r.task.OnFinished("remove from route table", func() { - streamRoutes.Delete(string(r.Alias)) - }) - r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort) + r.l.Info(). + Str("proto", string(r.Scheme.ListeningScheme)). + Int("port", int(r.Port.ListeningPort)). + Msg("listening") if r.HealthMon != nil { if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { - logrus.Warn("health monitor error: ", err) + E.LogWarn("health monitor error", err, &r.l) } } go r.acceptConnections() streamRoutes.Store(string(r.Alias), r) + r.task.OnFinished("remove from route table", func() { + streamRoutes.Delete(string(r.Alias)) + }) return nil } @@ -128,25 +130,28 @@ func (r *StreamRoute) acceptConnections() { case <-r.task.Context().Done(): return default: - conn, err := r.Accept() + conn, err := r.stream.Accept() if err != nil { select { case <-r.task.Context().Done(): default: - r.l.Error("accept connection error: ", err) - r.task.Finish(err) + E.LogError("accept connection error", err, &r.l) } + r.task.Finish(err) return } - connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr())) + if conn == nil { + panic("connection is nil") + } + connTask := r.task.Subtask("connection") go func() { - err := r.Handle(conn) + err := r.stream.Handle(conn) if err != nil && !errors.Is(err, context.Canceled) { - r.l.Error(err) + E.LogError("handle connection error", err, &r.l) + connTask.Finish(err) } else { - connTask.Finish("connection closed") + connTask.Finish("closed") } - conn.Close() }() } } diff --git a/internal/route/stream_impl.go b/internal/route/stream_impl.go new file mode 100644 index 0000000..6b29dc9 --- /dev/null +++ b/internal/route/stream_impl.go @@ -0,0 +1,115 @@ +package route + +import ( + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/yusing/go-proxy/internal/net/types" + T "github.com/yusing/go-proxy/internal/proxy/fields" + U "github.com/yusing/go-proxy/internal/utils" +) + +type ( + Stream struct { + *StreamRoute + + listener types.StreamListener + targetAddr net.Addr + } +) + +const ( + streamFirstConnBufferSize = 128 + streamDialTimeout = 5 * time.Second +) + +func NewStream(base *StreamRoute) *Stream { + return &Stream{ + StreamRoute: base, + } +} + +func (stream *Stream) Addr() net.Addr { + if stream.listener == nil { + panic("listener is nil") + } + return stream.listener.Addr() +} + +func (stream *Stream) Setup() error { + var lcfg net.ListenConfig + var err error + + switch stream.Scheme.ListeningScheme { + case "tcp": + stream.targetAddr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%v", stream.Host, stream.Port.ProxyPort)) + if err != nil { + return err + } + tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", fmt.Sprintf(":%v", stream.Port.ListeningPort)) + if err != nil { + return err + } + stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) + stream.listener = types.NetListener(tcpListener) + case "udp": + stream.targetAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", stream.Host, stream.Port.ProxyPort)) + if err != nil { + return err + } + udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", fmt.Sprintf(":%v", stream.Port.ListeningPort)) + if err != nil { + return err + } + udpConn, ok := udpListener.(*net.UDPConn) + if !ok { + udpListener.Close() + return errors.New("udp listener is not *net.UDPConn") + } + stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port) + stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr) + default: + panic("should not reach here") + } + + return nil +} + +func (stream *Stream) Accept() (types.StreamConn, error) { + if stream.listener == nil { + return nil, errors.New("listener is nil") + } + return stream.listener.Accept() +} + +func (stream *Stream) Handle(conn types.StreamConn) error { + switch conn := conn.(type) { + case *UDPConn: + switch stream := stream.listener.(type) { + case *UDPForwarder: + return stream.Handle(conn) + default: + return fmt.Errorf("unexpected listener type: %T", stream) + } + case io.ReadWriteCloser: + stream.task.OnCancel("close conn", func() { conn.Close() }) + + dialer := &net.Dialer{Timeout: streamDialTimeout} + dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) + if err != nil { + return err + } + defer dstConn.Close() + pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn) + return pipe.Start() + default: + return fmt.Errorf("unexpected conn type: %T", conn) + } +} + +func (stream *Stream) Close() error { + return stream.listener.Close() +} diff --git a/internal/route/tcp.go b/internal/route/tcp.go deleted file mode 100755 index d14b482..0000000 --- a/internal/route/tcp.go +++ /dev/null @@ -1,68 +0,0 @@ -package route - -// import ( -// "context" -// "fmt" -// "net" -// "time" - -// "github.com/yusing/go-proxy/internal/net/types" -// T "github.com/yusing/go-proxy/internal/proxy/fields" -// U "github.com/yusing/go-proxy/internal/utils" -// F "github.com/yusing/go-proxy/internal/utils/functional" -// ) - -// const tcpDialTimeout = 5 * time.Second - -// type ( -// TCPConnMap = F.Map[net.Conn, struct{}] -// TCPRoute struct { -// *StreamRoute -// listener *net.TCPListener -// } -// ) - -// func NewTCPRoute(base *StreamRoute) *TCPRoute { -// return &TCPRoute{StreamRoute: base} -// } - -// func (route *TCPRoute) Setup() error { -// var cfg net.ListenConfig -// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) -// if err != nil { -// return err -// } -// //! this read the allocated port from original ':0' -// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) -// route.listener = in.(*net.TCPListener) -// return nil -// } - -// func (route *TCPRoute) Accept() (types.StreamConn, error) { -// return route.listener.Accept() -// } - -// func (route *TCPRoute) Handle(c types.StreamConn) error { -// clientConn := c.(net.Conn) - -// defer clientConn.Close() -// route.task.OnCancel("close conn", func() { clientConn.Close() }) - -// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) - -// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) -// dialer := &net.Dialer{} - -// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) -// cancel() -// if err != nil { -// return err -// } - -// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) -// return pipe.Start() -// } - -// func (route *TCPRoute) Close() error { -// return route.listener.Close() -// } diff --git a/internal/route/udp.go b/internal/route/udp.go deleted file mode 100755 index 8f7ae89..0000000 --- a/internal/route/udp.go +++ /dev/null @@ -1,149 +0,0 @@ -package route - -// import ( -// "errors" -// "fmt" -// "io" -// "net" - -// "github.com/yusing/go-proxy/internal/net/types" -// T "github.com/yusing/go-proxy/internal/proxy/fields" -// U "github.com/yusing/go-proxy/internal/utils" -// F "github.com/yusing/go-proxy/internal/utils/functional" -// ) - -// type ( -// UDPRoute struct { -// *StreamRoute - -// connMap UDPConnMap - -// listeningConn net.PacketConn -// targetAddr *net.UDPAddr -// } -// UDPConn struct { -// key string -// src net.Conn -// dst net.Conn -// U.BidirectionalPipe -// } -// UDPConnMap = F.Map[string, *UDPConn] -// ) - -// var NewUDPConnMap = F.NewMap[UDPConnMap] - -// const udpBufferSize = 8192 - -// func NewUDPRoute(base *StreamRoute) *UDPRoute { -// return &UDPRoute{ -// StreamRoute: base, -// connMap: NewUDPConnMap(), -// } -// } - -// func (route *UDPRoute) Setup() error { -// var cfg net.ListenConfig -// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort)) -// if err != nil { -// return err -// } -// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) -// if err != nil { -// source.Close() -// return err -// } - -// //! this read the allocated listeningPort from original ':0' -// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) - -// route.listeningConn = source -// route.targetAddr = raddr - -// return nil -// } - -// func (route *UDPRoute) Accept() (types.StreamConn, error) { -// in := route.listeningConn - -// buffer := make([]byte, udpBufferSize) -// nRead, srcAddr, err := in.ReadFrom(buffer) -// if err != nil { -// return nil, err -// } - -// if nRead == 0 { -// return nil, io.ErrShortBuffer -// } - -// key := srcAddr.String() -// conn, ok := route.connMap.Load(key) - -// if !ok { -// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String()) -// if err != nil { -// return nil, err -// } -// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String()) -// if err != nil { -// srcConn.Close() -// return nil, err -// } -// conn = &UDPConn{ -// key, -// srcConn, -// dstConn, -// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), -// } -// route.connMap.Store(key, conn) -// } - -// _, err = conn.dst.Write(buffer[:nRead]) -// return conn, err -// } - -// func (route *UDPRoute) Handle(c types.StreamConn) error { -// switch c := c.(type) { -// case *UDPConn: -// err := c.Start() -// route.connMap.Delete(c.key) -// c.Close() -// return err -// case *net.TCPConn: -// in := route.listeningConn -// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr)) -// if err != nil { -// return err -// } -// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start() -// c.Close() -// return err -// } -// return fmt.Errorf("unknown conn type: %T", c) -// } - -// func (route *UDPRoute) Close() error { -// route.connMap.RangeAllParallel(func(k string, v *UDPConn) { -// v.Close() -// }) -// route.connMap.Clear() -// return route.listeningConn.Close() -// } - -// // Close implements types.StreamConn -// func (conn *UDPConn) Close() error { -// return errors.Join(conn.src.Close(), conn.dst.Close()) -// } - -// // RemoteAddr implements types.StreamConn -// func (conn *UDPConn) RemoteAddr() net.Addr { -// return conn.src.RemoteAddr() -// } - -// type sourceRWCloser struct { -// server net.PacketConn -// net.Conn -// } - -// func (w sourceRWCloser) Write(p []byte) (int, error) { -// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr)) -// } diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go new file mode 100644 index 0000000..8ed4532 --- /dev/null +++ b/internal/route/udp_forwarder.go @@ -0,0 +1,197 @@ +package route + +import ( + "context" + "fmt" + "net" + "sync" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/types" + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +type ( + UDPForwarder struct { + ctx context.Context + forwarder *net.UDPConn + dstAddr net.Addr + connMap F.Map[string, *UDPConn] + mu sync.Mutex + } + UDPConn struct { + srcAddr *net.UDPAddr + conn net.Conn + buf *UDPBuf + } + UDPBuf struct { + data, oob []byte + n, oobn int + } +) + +const udpConnBufferSize = 4096 + +func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder { + return &UDPForwarder{ + ctx: ctx, + forwarder: forwarder, + dstAddr: dstAddr, + connMap: F.NewMapOf[string, *UDPConn](), + } +} + +func newUDPBuf() *UDPBuf { + return &UDPBuf{ + data: make([]byte, udpConnBufferSize), + oob: make([]byte, udpConnBufferSize), + } +} + +func (conn *UDPConn) SrcAddrString() string { + return conn.srcAddr.Network() + "://" + conn.srcAddr.String() +} + +func (conn *UDPConn) DstAddrString() string { + return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() +} + +func (w *UDPForwarder) Addr() net.Addr { + return w.forwarder.LocalAddr() +} + +func (w *UDPForwarder) Accept() (types.StreamConn, error) { + buf := newUDPBuf() + addr, err := w.readFromListener(buf) + if err != nil { + return nil, err + } + return &UDPConn{ + srcAddr: addr, + buf: buf, + }, nil +} + +func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) { + switch dstAddr := w.dstAddr.(type) { + case *net.UDPAddr: + var laddr *net.UDPAddr + if dstAddr.IP.IsLoopback() { + laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:") + } + dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr) + case *net.TCPAddr: + dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr) + default: + err = fmt.Errorf("unsupported network %s", w.dstAddr.Network()) + } + return +} + +func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) { + buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob) + if err == nil { + logger.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) + } + return +} + +func (dst *UDPConn) read() (err error) { + switch dstConn := dst.conn.(type) { + case *net.UDPConn: + dst.buf.n, dst.buf.oobn, _, _, err = dstConn.ReadMsgUDP(dst.buf.data, dst.buf.oob) + default: + dst.buf.n, err = dstConn.Read(dst.buf.data[:dst.buf.n]) + dst.buf.oobn = 0 + } + if err == nil { + logger.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", dst.DstAddrString(), dst.buf.n, dst.buf.oobn) + } + return +} + +func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) { + buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr) + if err == nil { + logger.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) + } + return +} + +func (dst *UDPConn) write() (err error) { + switch dstConn := dst.conn.(type) { + case *net.UDPConn: + dst.buf.n, dst.buf.oobn, err = dstConn.WriteMsgUDP(dst.buf.data[:dst.buf.n], dst.buf.oob[:dst.buf.oobn], nil) + if err == nil { + logger.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", dst.DstAddrString(), dst.buf.n, dst.buf.oobn) + } + default: + _, err = dstConn.Write(dst.buf.data[:dst.buf.n]) + if err == nil { + logger.Debug().Msgf("write to dst %s success (n: %d)", dst.DstAddrString(), dst.buf.n) + } + } + + return nil +} + +func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { + conn, ok := streamConn.(*UDPConn) + if !ok { + panic("unexpected conn type") + } + key := conn.srcAddr.String() + + w.mu.Lock() + dst, ok := w.connMap.Load(key) + if !ok { + var err error + dst = conn + dst.conn, err = w.dialDst() + if err != nil { + return err + } + if err := dst.write(); err != nil { + dst.conn.Close() + return err + } + w.connMap.Store(key, dst) + } else { + conn.conn = dst.conn + if err := conn.write(); err != nil { + w.connMap.Delete(key) + dst.conn.Close() + return err + } + } + w.mu.Unlock() + + for { + select { + case <-w.ctx.Done(): + return nil + default: + if err := dst.read(); err != nil { + w.connMap.Delete(key) + dst.conn.Close() + return err + } + + if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil { + return err + } + } + } +} + +func (w *UDPForwarder) Close() error { + errs := E.NewBuilder("errors closing udp conn") + w.mu.Lock() + defer w.mu.Unlock() + w.connMap.RangeAll(func(key string, conn *UDPConn) { + errs.Add(conn.conn.Close()) + }) + w.connMap.Clear() + errs.Add(w.forwarder.Close()) + return errs.Error() +} diff --git a/internal/route/udp_listener.go b/internal/route/udp_listener.go deleted file mode 100644 index 04f3b65..0000000 --- a/internal/route/udp_listener.go +++ /dev/null @@ -1,73 +0,0 @@ -package route - -import ( - "context" - "io" - "net" - "sync" - - F "github.com/yusing/go-proxy/internal/utils/functional" -) - -type ( - UDPListener struct { - ctx context.Context - listener net.PacketConn - connMap UDPConnMap - mu sync.Mutex - } - UDPConnMap = F.Map[string, net.Conn] -) - -var NewUDPConnMap = F.NewMap[UDPConnMap] - -func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener { - return &UDPListener{ - ctx: ctx, - listener: listener, - connMap: NewUDPConnMap(), - } -} - -// Addr implements net.Listener. -func (route *UDPListener) Addr() net.Addr { - return route.listener.LocalAddr() -} - -func (udpl *UDPListener) Accept() (net.Conn, error) { - in := udpl.listener - - buffer := make([]byte, streamBufferSize) - nRead, srcAddr, err := in.ReadFrom(buffer) - if err != nil { - return nil, err - } - - if nRead == 0 { - return nil, io.ErrShortBuffer - } - - udpl.mu.Lock() - defer udpl.mu.Unlock() - - key := srcAddr.String() - conn, ok := udpl.connMap.Load(key) - if !ok { - dialer := &net.Dialer{Timeout: streamDialTimeout} - srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String()) - if err != nil { - return nil, err - } - udpl.connMap.Store(key, srcConn) - } - return conn, nil -} - -// Close implements net.Listener. -func (route *UDPListener) Close() error { - route.connMap.RangeAllParallel(func(key string, conn net.Conn) { - conn.Close() - }) - route.connMap.Clear() - return route.listener.Close() -} diff --git a/internal/server/server.go b/internal/server/server.go index 0d917ec..f6f4bf5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "errors" + "io" "log" + "net/http" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/autocert" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) @@ -23,6 +26,8 @@ type Server struct { startTime time.Time task task.Task + + l zerolog.Logger } type Options struct { @@ -34,22 +39,11 @@ type Options struct { Handler http.Handler } -type LogrusWrapper struct { - *logrus.Entry -} - -func (l LogrusWrapper) Write(b []byte) (int, error) { - return l.Logger.WriterLevel(logrus.ErrorLevel).Write(b) -} - func NewServer(opt Options) (s *Server) { var httpSer, httpsSer *http.Server var httpHandler http.Handler - logger := log.Default() - logger.SetOutput(LogrusWrapper{ - logrus.WithFields(logrus.Fields{"?": "server", "name": opt.Name}), - }) + logger := logging.With().Str("module", "server").Str("name", opt.Name).Logger() certAvailable := false if opt.CertProvider != nil { @@ -67,14 +61,14 @@ func NewServer(opt Options) (s *Server) { httpSer = &http.Server{ Addr: opt.HTTPAddr, Handler: httpHandler, - ErrorLog: logger, + ErrorLog: log.New(io.Discard, "", 0), // most are tls related } } if certAvailable && opt.HTTPSAddr != "" { httpsSer = &http.Server{ Addr: opt.HTTPSAddr, Handler: opt.Handler, - ErrorLog: logger, + ErrorLog: log.New(io.Discard, "", 0), // most are tls related TLSConfig: &tls.Config{ GetCertificate: opt.CertProvider.GetCert, }, @@ -86,6 +80,7 @@ func NewServer(opt Options) (s *Server) { http: httpSer, https: httpsSer, task: task.GlobalTask(opt.Name + " server"), + l: logger, } } @@ -101,19 +96,19 @@ func (s *Server) Start() { s.startTime = time.Now() if s.http != nil { - s.httpStarted = true - logrus.Printf("starting http %s server on %s", s.Name, s.http.Addr) go func() { s.handleErr("http", s.http.ListenAndServe()) }() + s.httpStarted = true + s.l.Info().Str("addr", s.http.Addr).Msg("server started") } if s.https != nil { - s.httpsStarted = true - logrus.Printf("starting https %s server on %s", s.Name, s.https.Addr) go func() { s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath())) }() + s.httpsStarted = true + s.l.Info().Str("addr", s.https.Addr).Msgf("server started") } s.task.OnFinished("stop server", s.stop) @@ -144,7 +139,7 @@ func (s *Server) handleErr(scheme string, err error) { case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): return default: - logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err) + s.l.Fatal().Err(err).Str("scheme", scheme).Msg("server error") } } @@ -162,5 +157,3 @@ func redirectToTLSHandler(port string) http.HandlerFunc { http.Redirect(w, r, r.URL.String(), redirectCode) } } - -var logger = logrus.WithField("module", "server") diff --git a/internal/task/task.go b/internal/task/task.go index 1ff7400..d848136 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -4,14 +4,14 @@ import ( "context" "errors" "fmt" - "runtime" "strings" "sync" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -100,7 +100,7 @@ type ( subtasks F.Set[*task] subTasksWg sync.WaitGroup - name, line string + name string OnFinishedFuncs []func() OnFinishedMu sync.Mutex @@ -113,6 +113,8 @@ type ( var ( ErrProgramExiting = errors.New("program exiting") ErrTaskCanceled = errors.New("task canceled") + + logger = logging.With().Str("module", "task").Logger() ) // GlobalTask returns a new Task with the given name, derived from the global context. @@ -159,14 +161,22 @@ func GlobalContextWait(timeout time.Duration) { case <-done: return case <-after: - logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) + logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) return } } } +func (t *task) trace() *zerolog.Event { + return logger.Trace().Str("name", t.name) +} + func (t *task) Name() string { - return t.name + if !common.IsTrace { + return t.name + } + parts := strings.Split(t.name, " > ") + return parts[len(parts)-1] } func (t *task) String() string { @@ -212,20 +222,18 @@ func (t *task) OnFinished(about string, fn func()) { onCompTask := GlobalTask(t.name + " > OnFinished > " + about) go t.runAllOnFinished(onCompTask) } - var file string - var line int - if common.IsTrace { - _, file, line, _ = runtime.Caller(1) - } idx := len(t.OnFinishedFuncs) wrapped := func() { defer func() { if err := recover(); err != nil { - logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err) + logger.Error(). + Str("name", t.name). + Interface("err", err). + Msg("panic in " + about) } }() fn() - logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about) + logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about) } t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped) } @@ -236,7 +244,7 @@ func (t *task) OnCancel(about string, fn func()) { <-t.ctx.Done() fn() onCompTask.Finish("done") - logrus.Tracef("%s > onCancel done: %s", t.name, about) + t.trace().Msg("onCancel done: " + about) }() } @@ -276,14 +284,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n parent.subTasksWg.Add(1) parent.subtasks.Add(subtask) if common.IsTrace { - _, file, line, ok := runtime.Caller(3) - if ok { - subtask.line = fmt.Sprintf("%s:%d", file, line) - } - logrus.Tracef("line %s\n%s started", subtask.line, name) + subtask.trace().Msg("started") go func() { subtask.Wait() - logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause()) + subtask.trace().Msg("finished: " + subtask.FinishCause().Error()) }() } go func() { @@ -324,12 +328,6 @@ func (t *task) tree(prefix ...string) string { pre = prefix[0] sb.WriteString(pre + "- ") } - if t.line != "" { - sb.WriteString("line " + t.line + "\n") - if len(pre) > 0 { - sb.WriteString(pre + "- ") - } - } sb.WriteString(t.Name() + "\n") t.subtasks.RangeAll(func(subtask *task) { sb.WriteString(subtask.tree(pre + " ")) @@ -341,7 +339,6 @@ func (t *task) tree(prefix ...string) string { // // The map contains the following keys: // - name: the name of the task -// - line: the line number of the task, if available // - subtasks: a slice of maps, each representing a subtask // // The subtask maps contain the same keys, recursively. @@ -354,11 +351,8 @@ func (t *task) tree(prefix ...string) string { // only. func (t *task) serialize() map[string]any { m := make(map[string]any) - parts := strings.Split(t.name, ">") - m["name"] = strings.TrimSpace(parts[len(parts)-1]) - if t.line != "" { - m["line"] = t.line - } + parts := strings.Split(t.name, " > ") + m["name"] = parts[len(parts)-1] if t.subtasks.Size() > 0 { m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size()) t.subtasks.RangeAll(func(subtask *task) { diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index 56669d1..8d95b17 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -1,10 +1,10 @@ package functional import ( + "errors" "sync" "github.com/puzpuzpuz/xsync/v3" - E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) @@ -12,6 +12,8 @@ type Map[KT comparable, VT any] struct { *xsync.MapOf[KT, VT] } +const minParallelSize = 4 + func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] { return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)} } @@ -113,6 +115,11 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { // // nothing func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { + if m.Size() < minParallelSize { + m.RangeAll(do) + return + } + var wg sync.WaitGroup m.Range(func(k KT, v VT) bool { @@ -126,6 +133,45 @@ func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { wg.Wait() } +// CollectErrors calls the given function for each key-value pair in the map, +// then returns a slice of errors collected. +func (m Map[KT, VT]) CollectErrors(do func(k KT, v VT) error) []error { + errs := make([]error, 0) + m.Range(func(k KT, v VT) bool { + if err := do(k, v); err != nil { + errs = append(errs, err) + } + return true + }) + return errs +} + +// CollectErrors calls the given function for each key-value pair in the map, +// then returns a slice of errors collected. +func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error { + if m.Size() < minParallelSize { + return m.CollectErrors(do) + } + + errs := make([]error, 0) + mu := sync.Mutex{} + wg := sync.WaitGroup{} + m.Range(func(k KT, v VT) bool { + wg.Add(1) + go func() { + if err := do(k, v); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + wg.Done() + }() + return true + }) + wg.Wait() + return errs +} + // RemoveAll removes all key-value pairs from the map where the value matches the given criteria. // // Parameters: @@ -160,12 +206,12 @@ func (m Map[KT, VT]) Has(k KT) bool { // Returns: // // error: if the unmarshaling fails -func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error { +func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) error { if m.Size() != 0 { - return E.FailedWhy("unmarshal from yaml", "map is not empty") + return errors.New("cannot unmarshal into non-empty map") } tmp := make(map[KT]VT) - if err := E.From(yaml.Unmarshal(data, tmp)); err != nil { + if err := yaml.Unmarshal(data, tmp); err != nil { return err } for k, v := range tmp { diff --git a/internal/utils/io.go b/internal/utils/io.go index ea79772..c60ff33 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -83,20 +83,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re } } -func (p BidirectionalPipe) Start() error { +func (p BidirectionalPipe) Start() E.Error { var wg sync.WaitGroup wg.Add(2) b := E.NewBuilder("bidirectional pipe error") go func() { - b.AddE(p.pSrcDst.Start()) + b.Add(p.pSrcDst.Start()) wg.Done() }() go func() { - b.AddE(p.pDstSrc.Start()) + b.Add(p.pDstSrc.Start()) wg.Done() }() wg.Wait() - return b.Build().Error() + return b.Error() } // Copyright 2009 The Go Authors. All rights reserved. @@ -152,18 +152,18 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) } -func LoadJSON[T any](path string, pointer *T) E.Error { - data, err := E.Check(os.ReadFile(path)) - if err.HasError() { +func LoadJSON[T any](path string, pointer *T) error { + data, err := os.ReadFile(path) + if err != nil { return err } - return E.From(json.Unmarshal(data, pointer)) + return json.Unmarshal(data, pointer) } -func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error { - data, err := E.Check(json.Marshal(pointer)) - if err.HasError() { +func SaveJSON[T any](path string, pointer *T, perm os.FileMode) error { + data, err := json.Marshal(pointer) + if err != nil { return err } - return E.From(os.WriteFile(path, data, perm)) + return os.WriteFile(path, data, perm) } diff --git a/internal/utils/nearest_field.go b/internal/utils/nearest_field.go new file mode 100644 index 0000000..01cae8b --- /dev/null +++ b/internal/utils/nearest_field.go @@ -0,0 +1,49 @@ +package utils + +import ( + "reflect" + + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func NearestField(input string, s any) string { + minDistance := -1 + nearestField := "" + var fields []string + switch s := s.(type) { + case []string: + fields = s + default: + t := reflect.TypeOf(s) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + fields = make([]string, 0) + for i := 0; i < t.NumField(); i++ { + jsonTag, ok := t.Field(i).Tag.Lookup("json") + if ok { + fields = append(fields, jsonTag) + } else { + fields = append(fields, t.Field(i).Name) + } + } + } else if t.Kind() == reflect.Map { + keys := reflect.ValueOf(s).MapKeys() + fields = make([]string, len(keys)) + for i, key := range keys { + fields[i] = key.String() + } + } else { + panic("unsupported type: " + t.String()) + } + } + for _, field := range fields { + distance := strutils.LevenshteinDistance(input, field) + if minDistance == -1 || distance < minDistance { + minDistance = distance + nearestField = field + } + } + return nearestField +} diff --git a/internal/utils/schema.go b/internal/utils/schema.go index dd66849..0ca0099 100644 --- a/internal/utils/schema.go +++ b/internal/utils/schema.go @@ -1,15 +1,23 @@ package utils import ( + "sync" + "github.com/santhosh-tekuri/jsonschema" ) var ( schemaCompiler = jsonschema.NewCompiler() schemaStorage = make(map[string]*jsonschema.Schema) + schemaMu sync.Mutex ) func GetSchema(path string) *jsonschema.Schema { + if schema, ok := schemaStorage[path]; ok { + return schema + } + schemaMu.Lock() + defer schemaMu.Unlock() if schema, ok := schemaStorage[path]; ok { return schema } diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 7565a5e..7ab9ce4 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" "reflect" "strconv" "strings" @@ -13,6 +12,7 @@ import ( "github.com/santhosh-tekuri/jsonschema" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" "gopkg.in/yaml.v3" ) @@ -23,17 +23,27 @@ type ( } ) +var ( + ErrInvalidType = E.New("invalid type") + ErrNilValue = E.New("nil") + ErrUnsettable = E.New("unsettable") + ErrUnsupportedConvertion = E.New("unsupported convertion") + ErrMapMissingColon = E.New("map missing colon") + ErrMapTooManyColons = E.New("map too many colons") + ErrUnknownField = E.New("unknown field") +) + func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { var i any err := yaml.Unmarshal(data, &i) if err != nil { - return E.FailWith("unmarshal yaml", err) + return E.From(err) } m, err := json.Marshal(i) if err != nil { - return E.FailWith("marshal json", err) + return E.From(err) } err = schema.Validate(bytes.NewReader(m)) @@ -43,14 +53,14 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { var valErr *jsonschema.ValidationError if !errors.As(err, &valErr) { - return E.UnexpectedError(err) + panic(err) } b := E.NewBuilder("yaml validation error") for _, e := range valErr.Causes { - b.Addf(e.Message) + b.Adds(e.Message) } - return b.Build() + return b.Error() } // Serialize converts the given data into a map[string]any representation. @@ -66,7 +76,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { // Returns: // - result: The resulting map[string]any representation of the data. // - error: An error if the data type is unsupported or if there is an error during conversion. -func Serialize(data any) (SerializedObject, E.Error) { +func Serialize(data any) (SerializedObject, error) { result := make(map[string]any) // Use reflection to inspect the data type @@ -74,7 +84,7 @@ func Serialize(data any) (SerializedObject, E.Error) { // Check if the value is valid if !value.IsValid() { - return nil, E.Invalid("data", fmt.Sprintf("type: %T", data)) + return nil, ErrInvalidType.Subjectf("%T", data) } // Dereference pointers if necessary @@ -123,7 +133,7 @@ func Serialize(data any) (SerializedObject, E.Error) { } } default: - return nil, E.Unsupported("type", value.Kind()) + return nil, errors.New("serialize: unsupported data type " + value.Kind().String()) } return result, nil @@ -139,11 +149,10 @@ func Serialize(data any) (SerializedObject, E.Error) { // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. func Deserialize(src SerializedObject, dst any) E.Error { if src == nil { - return E.Invalid("src", "nil") + return E.Errorf("deserialize: src is %w", ErrNilValue) } - if dst == nil { - return E.Invalid("nil dst", fmt.Sprintf("type: %T", dst)) + return E.Errorf("deserialize: dst is %w", ErrNilValue) } dstV := reflect.ValueOf(dst) @@ -151,7 +160,7 @@ func Deserialize(src SerializedObject, dst any) E.Error { if dstV.Kind() == reflect.Ptr { if dstV.IsNil() { - return E.Invalid("nil dst", fmt.Sprintf("type: %T", dst)) + return E.Errorf("deserialize: dst is %w", ErrNilValue) } dstV = dstV.Elem() dstT = dstV.Type() @@ -161,7 +170,7 @@ func Deserialize(src SerializedObject, dst any) E.Error { // convert target fields to lower no-snake // then check if the field of data is in the target - // TODO: use E.Builder to collect errors from all fields + errs := E.NewBuilder("deserialize error") switch dstV.Kind() { case reflect.Struct: @@ -173,12 +182,13 @@ func Deserialize(src SerializedObject, dst any) E.Error { if field, ok := mapping[ToLowerNoSnake(k)]; ok { err := Convert(reflect.ValueOf(v), field) if err != nil { - return err.Subject(k) + errs.Add(err.Subject(k)) } } else { - return E.Unexpected("field", k).Subjectf("%T", dst) + errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dst)))) } } + return errs.Error() case reflect.Map: if dstV.IsNil() { dstV.Set(reflect.MakeMap(dstT)) @@ -187,15 +197,14 @@ func Deserialize(src SerializedObject, dst any) E.Error { tmp := reflect.New(dstT.Elem()).Elem() err := Convert(reflect.ValueOf(src[k]), tmp) if err != nil { - return err.Subject(k) + errs.Add(err.Subject(k)) } dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp) } + return errs.Error() default: - return E.Unsupported("target type", fmt.Sprintf("%T", dst)) + return ErrUnsupportedConvertion.Subject("deserialize to " + dstT.String()) } - - return nil } // Convert attempts to convert the src to dst. @@ -220,7 +229,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { } if !dst.CanSet() { - return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface())) + return ErrUnsettable.Subject(dstT.String()) } if dst.Kind() == reflect.Pointer { @@ -241,12 +250,12 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { case srcT.Kind() == reflect.Map: obj, ok := src.Interface().(SerializedObject) if !ok { - return E.TypeMismatch[SerializedObject](src.Interface()) + return ErrUnsupportedConvertion.Subject(dstT.String() + " to " + srcT.String()) } return Deserialize(obj, dst.Addr().Interface()) case srcT.Kind() == reflect.Slice: if dstT.Kind() != reflect.Slice { - return E.TypeError("slice", srcT, dstT) + return ErrUnsupportedConvertion.Subject(dstT.String() + " to slice") } newSlice := reflect.MakeSlice(dstT, 0, src.Len()) i := 0 @@ -271,7 +280,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { var ok bool // check if (*T).Convertor is implemented if converter, ok = dst.Addr().Interface().(Converter); !ok { - return E.TypeError("conversion", srcT, dstT) + return ErrUnsupportedConvertion.Subjectf("%s to %s", srcT, dstT) } return converter.ConvertFrom(src.Interface()) @@ -297,8 +306,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E } d, err := time.ParseDuration(src) if err != nil { - convErr = E.Invalid("duration", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(d)) return @@ -308,24 +316,21 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E case reflect.Bool: b, err := strconv.ParseBool(src) if err != nil { - convErr = E.Invalid("boolean", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(b)) return case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, err := strconv.ParseInt(src, 10, 64) if err != nil { - convErr = E.Invalid("int", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(i).Convert(dst.Type())) return case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i, err := strconv.ParseUint(src, 10, 64) if err != nil { - convErr = E.Invalid("uint", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(i).Convert(dst.Type())) return @@ -340,7 +345,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E case reflect.Slice: // one liner is comma separated list if len(lines) == 0 { - dst.Set(reflect.ValueOf(CommaSeperatedList(src))) + dst.Set(reflect.ValueOf(strutils.CommaSeperatedList(src))) return } sl := make([]string, 0, len(lines)) @@ -356,35 +361,36 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E tmp = sl case reflect.Map: m := make(map[string]string, len(lines)) + errs := E.NewBuilder("invalid map") for i, line := range lines { parts := strings.Split(line, ":") if len(parts) < 2 { - convErr = E.Invalid("map", "missing colon").Subjectf("line#%d", i+1).With(line) - return + errs.Add(ErrMapMissingColon.Subjectf("line %d", i+1)) } if len(parts) > 2 { - convErr = E.Invalid("map", "too many colons").Subjectf("line#%d", i+1).With(line) - return + errs.Add(ErrMapTooManyColons.Subjectf("line %d", i+1)) } k := strings.TrimSpace(parts[0]) v := strings.TrimSpace(parts[1]) m[k] = v } + if errs.HasError() { + return true, errs.Error() + } tmp = m } if tmp == nil { - convertible = false - return + return false, nil } return true, Convert(reflect.ValueOf(tmp), dst) } -func DeserializeJSON(j map[string]string, target any) E.Error { - data, err := E.Check(json.Marshal(j)) +func DeserializeJSON(j map[string]string, target any) error { + data, err := json.Marshal(j) if err != nil { return err } - return E.From(json.Unmarshal(data, target)) + return json.Unmarshal(data, target) } func ToLowerNoSnake(s string) string { diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 632d852..29433ee 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -1,6 +1,7 @@ package utils import ( + "errors" "reflect" "testing" @@ -37,14 +38,14 @@ var testStructSerialized = map[string]any{ func TestSerialize(t *testing.T) { s, err := Serialize(testStruct) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, s, testStructSerialized) } func TestDeserialize(t *testing.T) { var s S err := Deserialize(testStructSerialized, &s) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, s, testStruct) } @@ -65,42 +66,42 @@ func TestStringIntConvert(t *testing.T) { ok, err := ConvertString(s, reflect.ValueOf(&test.i8)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i8, int8(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i16)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i16, int16(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i32)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i32, int32(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i64)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i64, int64(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u8)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u8, uint8(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u16)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u16, uint16(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u32)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u32, uint32(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u64)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u64, uint64(127)) } @@ -113,6 +114,8 @@ type testType struct { bar string } +var errInvalid = errors.New("invalid input type") + func (c *testType) ConvertFrom(v any) E.Error { switch v := v.(type) { case string: @@ -122,14 +125,14 @@ func (c *testType) ConvertFrom(v any) E.Error { c.foo = v return nil default: - return E.Invalid("input type", v) + return E.Errorf("%w %T", errInvalid, v) } } func TestConvertor(t *testing.T) { t.Run("string", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m).Error()) + ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m)) ExpectEqual(t, m.Test.foo, 0) ExpectEqual(t, m.Test.bar, "bar") @@ -137,7 +140,7 @@ func TestConvertor(t *testing.T) { t.Run("int", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m).Error()) + ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m)) ExpectEqual(t, m.Test.foo, 123) ExpectEqual(t, m.Test.bar, "") @@ -145,6 +148,6 @@ func TestConvertor(t *testing.T) { t.Run("invalid", func(t *testing.T) { m := new(testModel) - ExpectError(t, E.ErrInvalid, Deserialize(map[string]any{"Test": 123.456}, m).Error()) + ExpectError(t, errInvalid, Deserialize(map[string]any{"Test": 123.456}, m)) }) } diff --git a/internal/utils/string.go b/internal/utils/string.go deleted file mode 100644 index 9cbabc3..0000000 --- a/internal/utils/string.go +++ /dev/null @@ -1,35 +0,0 @@ -package utils - -import ( - "net/url" - "strconv" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" -) - -func CommaSeperatedList(s string) []string { - res := strings.Split(s, ",") - for i, part := range res { - res[i] = strings.TrimSpace(part) - } - return res -} - -func Title(s string) string { - // TODO: support other languages. - return cases.Title(language.AmericanEnglish).String(s) -} - -func ExtractPort(fullURL string) (int, error) { - url, err := url.Parse(fullURL) - if err != nil { - return 0, err - } - return strconv.Atoi(url.Port()) -} - -func PortString(port uint16) string { - return strconv.FormatUint(uint64(port), 10) -} diff --git a/internal/utils/strutils/ansi/ansi.go b/internal/utils/strutils/ansi/ansi.go new file mode 100644 index 0000000..da89ba9 --- /dev/null +++ b/internal/utils/strutils/ansi/ansi.go @@ -0,0 +1,25 @@ +package ansi + +import "regexp" + +var ansiRegexp = regexp.MustCompile(`\x1b\[[0-9;]*m`) + +const ( + BrightRed = "\x1b[91m" + BrightGreen = "\x1b[92m" + BrightYellow = "\x1b[93m" + BrightCyan = "\x1b[96m" + BrightWhite = "\x1b[97m" + Bold = "\x1b[1m" + Reset = "\x1b[0m" + + HighlightRed = BrightRed + Bold + HighlightGreen = BrightGreen + Bold + HighlightYellow = BrightYellow + Bold + HighlightCyan = BrightCyan + Bold + HighlightWhite = BrightWhite + Bold +) + +func StripANSI(s string) string { + return ansiRegexp.ReplaceAllString(s, "") +} diff --git a/internal/utils/format.go b/internal/utils/strutils/format.go similarity index 88% rename from internal/utils/format.go rename to internal/utils/strutils/format.go index a77875f..c5027fc 100644 --- a/internal/utils/format.go +++ b/internal/utils/strutils/format.go @@ -1,9 +1,11 @@ -package utils +package strutils import ( "fmt" "strings" "time" + + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) func FormatDuration(d time.Duration) string { @@ -55,6 +57,10 @@ func ParseBool(s string) bool { } } +func DoYouMean(s string) string { + return "Did you mean " + ansi.HighlightGreen + s + ansi.Reset + "?" +} + func pluralize(n int64) string { if n > 1 { return "s" diff --git a/internal/utils/strutils/strconv.go b/internal/utils/strutils/strconv.go new file mode 100644 index 0000000..a6979dc --- /dev/null +++ b/internal/utils/strutils/strconv.go @@ -0,0 +1,17 @@ +package strutils + +import ( + "errors" + "strconv" + + E "github.com/yusing/go-proxy/internal/error" +) + +func Atoi(s string) (int, E.Error) { + val, err := strconv.Atoi(s) + if err != nil { + return val, E.From(errors.Unwrap(err)).Subject(s) + } + + return val, nil +} diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go new file mode 100644 index 0000000..a2af3a0 --- /dev/null +++ b/internal/utils/strutils/string.go @@ -0,0 +1,82 @@ +package strutils + +import ( + "net/url" + "strconv" + "strings" + + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +func CommaSeperatedList(s string) []string { + res := strings.Split(s, ",") + for i, part := range res { + res[i] = strings.TrimSpace(part) + } + return res +} + +func Title(s string) string { + return cases.Title(language.AmericanEnglish).String(s) +} + +func ExtractPort(fullURL string) (int, error) { + url, err := url.Parse(fullURL) + if err != nil { + return 0, err + } + return Atoi(url.Port()) +} + +func PortString(port uint16) string { + return strconv.FormatUint(uint64(port), 10) +} + +func LevenshteinDistance(a, b string) int { + if a == b { + return 0 + } + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + + v0 := make([]int, len(b)+1) + v1 := make([]int, len(b)+1) + + for i := 0; i <= len(b); i++ { + v0[i] = i + } + + for i := 0; i < len(a); i++ { + v1[0] = i + 1 + + for j := 0; j < len(b); j++ { + cost := 0 + if a[i] != b[j] { + cost = 1 + } + + v1[j+1] = min(v1[j]+1, v0[j+1]+1, v0[j]+cost) + } + + for j := 0; j <= len(b); j++ { + v0[j] = v1[j] + } + } + + return v1[len(b)] +} + +func min(a, b, c int) int { + if a < b && a < c { + return a + } + if b < a && b < c { + return b + } + return c +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index 8331d0c..a812560 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) func init() { @@ -52,6 +52,15 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) { } } +func ExpectStrEqual(t *testing.T, got string, want string) { + t.Helper() + got = ansi.StripANSI(got) + if got != want { + t.Errorf("expected:\n%v, got\n%v", want, got) + t.FailNow() + } +} + func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) { t.Helper() for _, want := range wants { @@ -97,17 +106,3 @@ func ExpectType[T any](t *testing.T, got any) (_ T) { } return got.(T) } - -func Must[T any](v T, err E.Error) T { - if err != nil { - panic(err) - } - return v -} - -func Must2[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} diff --git a/internal/watcher/config_file_watcher.go b/internal/watcher/config_file_watcher.go index 4f7c512..31087ea 100644 --- a/internal/watcher/config_file_watcher.go +++ b/internal/watcher/config_file_watcher.go @@ -1,10 +1,10 @@ package watcher import ( - "context" "sync" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/task" ) var ( @@ -17,7 +17,7 @@ func NewConfigFileWatcher(filename string) Watcher { configDirWatcherMu.Lock() defer configDirWatcherMu.Unlock() if configDirWatcher == nil { - configDirWatcher = NewDirectoryWatcher(context.Background(), common.ConfigBasePath) + configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath) } return configDirWatcher.Add(filename) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index d729201..d910650 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -7,13 +7,16 @@ import ( "sync" "github.com/fsnotify/fsnotify" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/task" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/events" ) type DirWatcher struct { + zerolog.Logger + dir string w *fsnotify.Watcher @@ -23,7 +26,7 @@ type DirWatcher struct { eventCh chan Event errCh chan E.Error - ctx context.Context + task task.Task } // NewDirectoryWatcher returns a DirWatcher instance. @@ -34,22 +37,26 @@ type DirWatcher struct { // // Note that the returned DirWatcher is not ready to use until the goroutine // started by NewDirectoryWatcher has finished. -func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher { +func NewDirectoryWatcher(callerSubtask task.Task, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { - logrus.Panicf("unable to create fs watcher: %s", err) + logger.Panic().Err(err).Msg("unable to create fs watcher") } if err = w.Add(dirPath); err != nil { - logrus.Panicf("unable to create fs watcher: %s", err) + logger.Panic().Err(err).Msg("unable to create fs watcher") } helper := &DirWatcher{ + Logger: logger.With(). + Str("type", "dir"). + Str("path", dirPath). + Logger(), dir: dirPath, w: w, fwMap: F.NewMapOf[string, *fileWatcher](), eventCh: make(chan Event), errCh: make(chan E.Error), - ctx: ctx, + task: callerSubtask, } go helper.start() return helper @@ -73,14 +80,10 @@ func (h *DirWatcher) Add(relPath string) Watcher { eventCh: make(chan Event), errCh: make(chan E.Error), } - go func() { - defer func() { - close(s.eventCh) - close(s.errCh) - }() - <-h.ctx.Done() - logrus.Debugf("file watcher %s stopped", relPath) - }() + h.task.OnFinished("close file watcher for "+relPath, func() { + close(s.eventCh) + close(s.errCh) + }) h.fwMap.Store(relPath, s) return s } @@ -88,11 +91,10 @@ func (h *DirWatcher) Add(relPath string) Watcher { func (h *DirWatcher) start() { defer close(h.eventCh) defer h.w.Close() - defer logrus.Debugf("directory watcher %s stopped", h.dir) for { select { - case <-h.ctx.Done(): + case <-h.task.Context().Done(): return case fsEvent, ok := <-h.w.Events: if !ok { @@ -122,9 +124,9 @@ func (h *DirWatcher) start() { // send event to directory watcher select { case h.eventCh <- msg: - logrus.Debugf("sent event to directory watcher %s", h.dir) + h.Debug().Msg("sent event to directory watcher") default: - logrus.Debugf("failed to send event to directory watcher %s", h.dir) + h.Debug().Msg("failed to send event to directory watcher") } // send event to file watcher too @@ -132,12 +134,12 @@ func (h *DirWatcher) start() { if ok { select { case w.eventCh <- msg: - logrus.Debugf("sent event to file watcher %s", relPath) + h.Debug().Msg("sent event to file watcher " + relPath) default: - logrus.Debugf("failed to send event to file watcher %s", relPath) + h.Debug().Msg("failed to send event to file watcher " + relPath) } } else { - logrus.Debugf("file watcher not found: %s", relPath) + h.Debug().Msg("file watcher not found: " + relPath) } case err := <-h.w.Errors: if errors.Is(err, fsnotify.ErrClosed) { diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 59bbe6b..5cb78b3 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -2,12 +2,11 @@ package watcher import ( "context" - "fmt" "time" docker_events "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/watcher/events" @@ -15,10 +14,11 @@ import ( type ( DockerWatcher struct { + zerolog.Logger + host string client D.Client clientOwned bool - logrus.FieldLogger } DockerListOptions = docker_events.ListOptions ) @@ -47,7 +47,7 @@ var ( dockerWatcherRetryInterval = 3 * time.Second ) -func DockerrFilterContainer(nameOrID string) filters.KeyValuePair { +func DockerFilterContainerNameID(nameOrID string) filters.KeyValuePair { return filters.Arg("container", nameOrID) } @@ -55,18 +55,21 @@ func NewDockerWatcher(host string) DockerWatcher { return DockerWatcher{ host: host, clientOwned: true, - FieldLogger: (logrus. - WithField("module", "docker_watcher"). - WithField("host", host)), + Logger: logger.With(). + Str("type", "docker"). + Str("host", host). + Logger(), } } func NewDockerWatcherWithClient(client D.Client) DockerWatcher { return DockerWatcher{ client: client, - FieldLogger: (logrus. - WithField("module", "docker_watcher"). - WithField("host", client.DaemonHost()))} + Logger: logger.With(). + Str("type", "docker"). + Str("host", client.DaemonHost()). + Logger(), + } } func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) { @@ -88,7 +91,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList }() if !w.client.Connected() { - var err E.Error + var err error attempts := 0 for { w.client, err = D.ConnectClient(w.host) @@ -96,7 +99,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList break } attempts++ - errCh <- E.FailWith(fmt.Sprintf("docker connection attempt #%d", attempts), err) + errCh <- E.Errorf("docker connection attempt #%d: %w", attempts, err) select { case <-ctx.Done(): return @@ -113,14 +116,14 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList for { select { case <-ctx.Done(): - if err := E.From(ctx.Err()); err != nil && err.IsNot(context.Canceled) { + if err := E.From(ctx.Err()); err != nil && !err.Is(context.Canceled) { errCh <- err } return case msg := <-cEventCh: action, ok := events.DockerEventMap[msg.Action] if !ok { - w.Debugf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"]) + w.Debug().Msgf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"]) continue } event := Event{ diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go index 992166e..d2926f6 100644 --- a/internal/watcher/events/event_queue.go +++ b/internal/watcher/events/event_queue.go @@ -65,7 +65,7 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { go func() { defer func() { if err := recover(); err != nil { - e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name())) + e.onError(E.Errorf("recovered panic in onFlush: %v", err).Subject(e.task.Parent().String())) } }() e.onFlush(flushTask, queue) diff --git a/internal/watcher/health/json.go b/internal/watcher/health/json.go index 0c8ec23..9123423 100644 --- a/internal/watcher/health/json.go +++ b/internal/watcher/health/json.go @@ -5,7 +5,7 @@ import ( "time" "github.com/yusing/go-proxy/internal/net/types" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type JSONRepresentation struct { @@ -27,10 +27,10 @@ func (jsonRepr *JSONRepresentation) MarshalJSON() ([]byte, error) { "name": jsonRepr.Name, "config": jsonRepr.Config, "started": jsonRepr.Started.Unix(), - "startedStr": U.FormatTime(jsonRepr.Started), + "startedStr": strutils.FormatTime(jsonRepr.Started), "status": jsonRepr.Status.String(), "uptime": jsonRepr.Uptime.Seconds(), - "uptimeStr": U.FormatDuration(jsonRepr.Uptime), + "uptimeStr": strutils.FormatDuration(jsonRepr.Uptime), "url": url, "extra": jsonRepr.Extra, }) diff --git a/internal/watcher/health/logger.go b/internal/watcher/health/logger.go index 171f4a5..3735a4c 100644 --- a/internal/watcher/health/logger.go +++ b/internal/watcher/health/logger.go @@ -1,5 +1,7 @@ package health -import "github.com/sirupsen/logrus" +import ( + "github.com/yusing/go-proxy/internal/logging" +) -var logger = logrus.WithField("module", "health_mon") +var logger = logging.With().Str("module", "health_mon").Logger() diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index 41773ec..0a516bf 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -3,10 +3,14 @@ package health import ( "context" "errors" + "fmt" + "strings" "time" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -29,6 +33,10 @@ type ( var monMap = F.NewMapOf[string, HealthMonitor]() +var ( + ErrNegativeInterval = errors.New("negative interval") +) + func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { mon := &monitor{ config: config, @@ -59,10 +67,12 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { mon.task = routeSubtask if mon.config.Interval <= 0 { - return E.Invalid("interval", mon.config.Interval) + return E.From(ErrNegativeInterval) } go func() { + logger := logging.With().Str("name", mon.service).Logger() + defer func() { if mon.status.Load() != StatusError { mon.status.Store(StatusUnknown) @@ -70,7 +80,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { }() if err := mon.checkUpdateHealth(); err != nil { - logger.Errorf("healthchecker %s failure: %s", mon.service, err) + logger.Err(err).Msg("healthchecker failure") return } @@ -87,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { case <-ticker.C: err := mon.checkUpdateHealth() if err != nil { - logger.Errorf("healthchecker %s failure: %s", mon.service, err) + logger.Err(err).Msg("healthchecker failure") return } } @@ -128,10 +138,8 @@ func (mon *monitor) Uptime() time.Duration { // Name implements HealthMonitor. func (mon *monitor) Name() string { - if mon.task == nil { - return "" - } - return mon.task.Name() + parts := strings.Split(mon.service, "/") + return parts[len(parts)-1] } // String implements fmt.Stringer of HealthMonitor. @@ -151,13 +159,14 @@ func (mon *monitor) MarshalJSON() ([]byte, error) { }).MarshalJSON() } -func (mon *monitor) checkUpdateHealth() E.Error { +func (mon *monitor) checkUpdateHealth() error { + logger := logging.With().Str("name", mon.Name()).Logger() healthy, detail, err := mon.checkHealth() if err != nil { defer mon.task.Finish(err) mon.status.Store(StatusError) if !errors.Is(err, context.Canceled) { - return E.Failure("check health").With(err) + return fmt.Errorf("check health: %w", err) } return nil } @@ -169,9 +178,12 @@ func (mon *monitor) checkUpdateHealth() E.Error { } if healthy != (mon.status.Swap(status) == StatusHealthy) { if healthy { - logger.Infof("%s is up", mon.service) + logger.Info().Msg("server is up") + notif.Notify(mon.service, "server is up") } else { - logger.Warnf("%s is down: %s", mon.service, detail) + logger.Warn().Msg("server is down") + logger.Debug().Msg(detail) + notif.Notify(mon.service, "server is down") } } diff --git a/internal/watcher/logger.go b/internal/watcher/logger.go new file mode 100644 index 0000000..ad15593 --- /dev/null +++ b/internal/watcher/logger.go @@ -0,0 +1,5 @@ +package watcher + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "watcher").Logger()