From e5bbb184143dc1779450b1296b27a905179d56dd Mon Sep 17 00:00:00 2001 From: yusing Date: Tue, 29 Oct 2024 11:34:58 +0800 Subject: [PATCH] migrated from logrus to zerolog, improved error formatting, fixed concurrent map write, fixed crash on rapid page refresh for idle containers, fixed infinite recursion on gotfiy error, fixed websocket connection problem when using idlewatcher --- Makefile | 3 + cmd/main.go | 58 +-- go.mod | 7 +- go.sum | 18 +- internal/api/handler.go | 4 +- internal/api/v1/errorpage/http_handler.go | 31 -- internal/api/v1/file.go | 11 +- internal/api/v1/query/query.go | 15 +- internal/api/v1/reload.go | 6 +- internal/api/v1/stats.go | 10 +- internal/api/v1/utils/error.go | 18 +- internal/api/v1/utils/http_client.go | 11 +- internal/api/v1/utils/logging.go | 18 + internal/api/v1/utils/utils.go | 22 +- internal/autocert/config.go | 40 ++- internal/autocert/constants.go | 3 - internal/autocert/logger.go | 5 + internal/autocert/provider.go | 120 +++---- internal/autocert/provider_test/ovh_test.go | 2 +- internal/autocert/setup.go | 4 +- internal/common/args.go | 7 +- internal/common/env.go | 6 +- internal/config/config.go | 91 ++--- internal/config/query.go | 4 +- internal/docker/client.go | 27 +- internal/docker/container.go | 6 +- internal/docker/container_helper.go | 8 +- internal/docker/idlewatcher/types/config.go | 43 +-- internal/docker/idlewatcher/waker.go | 5 +- internal/docker/idlewatcher/waker_http.go | 46 +-- internal/docker/idlewatcher/waker_stream.go | 24 +- internal/docker/idlewatcher/watcher.go | 63 ++-- internal/docker/inspect.go | 12 +- internal/docker/label.go | 32 +- internal/docker/label_test.go | 26 +- internal/docker/list_containers.go | 13 +- internal/docker/logger.go | 6 +- internal/error/base.go | 46 +++ internal/error/builder.go | 156 ++++---- internal/error/builder_test.go | 62 ++-- internal/error/error.go | 336 ++---------------- internal/error/error_test.go | 210 ++++++----- internal/error/errors.go | 83 ----- internal/error/log.go | 43 +++ internal/error/nested_error.go | 120 +++++++ internal/error/subject.go | 50 +++ internal/error/utils.go | 71 ++++ internal/list-icons.go | 2 +- internal/logging/logging.go | 69 ++++ internal/net/http/dummy_response_writer.go | 15 + .../loadbalancer/dummy_response_writer.go | 15 - internal/net/http/loadbalancer/ip_hash.go | 8 +- internal/net/http/loadbalancer/least_conn.go | 4 +- .../net/http/loadbalancer/loadbalancer.go | 23 +- internal/net/http/loadbalancer/logger.go | 4 +- internal/net/http/logger.go | 5 + .../net/http/middleware/cidr_whitelist.go | 8 +- .../http/middleware/cidr_whitelist_test.go | 12 +- .../net/http/middleware/cloudflare_real_ip.go | 8 +- .../net/http/middleware/custom_error_page.go | 40 ++- .../http/middleware}/errorpage/error_page.go | 28 +- .../net/http/middleware/errorpage/logger.go | 5 + internal/net/http/middleware/forward_auth.go | 19 +- internal/net/http/middleware/logger.go | 5 + internal/net/http/middleware/middleware.go | 38 +- .../net/http/middleware/middleware_builder.go | 48 ++- .../middleware/middleware_builder_test.go | 8 +- internal/net/http/middleware/middlewares.go | 56 +-- .../net/http/middleware/modify_request.go | 6 +- .../http/middleware/modify_request_test.go | 4 +- .../net/http/middleware/modify_response.go | 6 +- .../http/middleware/modify_response_test.go | 4 +- internal/net/http/middleware/oauth2.go | 216 ++++++----- internal/net/http/middleware/real_ip.go | 6 +- internal/net/http/middleware/real_ip_test.go | 8 +- .../net/http/middleware/redirect_http_test.go | 4 +- internal/net/http/middleware/trace.go | 4 +- internal/net/http/modify_response_writer.go | 3 +- internal/net/http/reverse_proxy_mod.go | 22 +- internal/net/types/cidr.go | 9 +- internal/net/types/stream.go | 29 +- internal/notif/dispatcher.go | 70 ++-- internal/notif/gotify.go | 25 +- internal/notif/logrus.go | 21 -- internal/notif/providers.go | 3 +- internal/proxy/entry/entry.go | 12 +- internal/proxy/entry/raw.go | 9 +- internal/proxy/entry/reverse_proxy.go | 29 +- internal/proxy/entry/stream.go | 25 +- internal/proxy/fields/alias.go | 5 +- internal/proxy/fields/host.go | 6 +- internal/proxy/fields/path_pattern.go | 21 +- internal/proxy/fields/path_pattern_test.go | 6 +- internal/proxy/fields/port.go | 13 +- internal/proxy/fields/scheme.go | 6 +- internal/proxy/fields/stream_port.go | 37 +- internal/proxy/fields/stream_port_test.go | 18 +- internal/proxy/fields/stream_scheme.go | 19 +- internal/proxy/fields/stream_scheme_test.go | 37 ++ internal/route/http.go | 22 +- internal/route/logger.go | 5 + internal/route/provider/docker.go | 136 +++---- internal/route/provider/docker_test.go | 124 +++---- internal/route/provider/event_handler.go | 60 ++-- internal/route/provider/file.go | 41 +-- internal/route/provider/logger.go | 5 + internal/route/provider/provider.go | 50 +-- internal/route/raw.go | 94 ----- internal/route/route.go | 2 +- internal/route/stream.go | 69 ++-- internal/route/stream_impl.go | 115 ++++++ internal/route/tcp.go | 68 ---- internal/route/udp.go | 149 -------- internal/route/udp_forwarder.go | 197 ++++++++++ internal/route/udp_listener.go | 73 ---- internal/server/server.go | 37 +- internal/task/task.go | 56 ++- internal/utils/functional/map.go | 54 ++- internal/utils/io.go | 24 +- internal/utils/nearest_field.go | 49 +++ internal/utils/schema.go | 8 + internal/utils/serialization.go | 90 ++--- internal/utils/serialization_test.go | 31 +- internal/utils/string.go | 35 -- internal/utils/strutils/ansi/ansi.go | 25 ++ internal/utils/{ => strutils}/format.go | 8 +- internal/utils/strutils/strconv.go | 17 + internal/utils/strutils/string.go | 82 +++++ internal/utils/testing/testing.go | 25 +- internal/watcher/config_file_watcher.go | 4 +- internal/watcher/directory_watcher.go | 44 +-- internal/watcher/docker_watcher.go | 31 +- internal/watcher/events/event_queue.go | 2 +- internal/watcher/health/json.go | 6 +- internal/watcher/health/logger.go | 6 +- internal/watcher/health/monitor.go | 34 +- internal/watcher/logger.go | 5 + 137 files changed, 2640 insertions(+), 2348 deletions(-) delete mode 100644 internal/api/v1/errorpage/http_handler.go create mode 100644 internal/api/v1/utils/logging.go create mode 100644 internal/autocert/logger.go create mode 100644 internal/error/base.go delete mode 100644 internal/error/errors.go create mode 100644 internal/error/log.go create mode 100644 internal/error/nested_error.go create mode 100644 internal/error/subject.go create mode 100644 internal/error/utils.go create mode 100644 internal/logging/logging.go create mode 100644 internal/net/http/dummy_response_writer.go delete mode 100644 internal/net/http/loadbalancer/dummy_response_writer.go create mode 100644 internal/net/http/logger.go rename internal/{api/v1 => net/http/middleware}/errorpage/error_page.go (66%) create mode 100644 internal/net/http/middleware/errorpage/logger.go create mode 100644 internal/net/http/middleware/logger.go delete mode 100644 internal/notif/logrus.go create mode 100644 internal/proxy/fields/stream_scheme_test.go create mode 100644 internal/route/logger.go create mode 100644 internal/route/provider/logger.go delete mode 100644 internal/route/raw.go create mode 100644 internal/route/stream_impl.go delete mode 100755 internal/route/tcp.go delete mode 100755 internal/route/udp.go create mode 100644 internal/route/udp_forwarder.go delete mode 100644 internal/route/udp_listener.go create mode 100644 internal/utils/nearest_field.go delete mode 100644 internal/utils/string.go create mode 100644 internal/utils/strutils/ansi/ansi.go rename internal/utils/{ => strutils}/format.go (88%) create mode 100644 internal/utils/strutils/strconv.go create mode 100644 internal/utils/strutils/string.go create mode 100644 internal/watcher/logger.go diff --git a/Makefile b/Makefile index 6aa0a9e..cba758c 100755 --- a/Makefile +++ b/Makefile @@ -65,3 +65,6 @@ debug-list-containers: ci-test: mkdir -p /tmp/artifacts act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" + +cloc: + cloc --not-match-f '_test.go$$' cmd internal pkg \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go index 7a0be76..6657753 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,7 +2,6 @@ package main import ( "encoding/json" - "io" "log" "net/http" "os" @@ -10,15 +9,14 @@ import ( "syscall" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/middleware" - "github.com/yusing/go-proxy/internal/notif" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/task" @@ -33,44 +31,26 @@ func main() { return } - l := logrus.WithField("module", "main") - timeFmt := "01-02 15:04:05" - fullTS := true - - if common.IsTrace { - logrus.SetLevel(logrus.TraceLevel) - timeFmt = "04:05" - fullTS = false - } else if common.IsDebug { - logrus.SetLevel(logrus.DebugLevel) - } - - if args.Command != common.CommandStart { - logrus.SetOutput(io.Discard) - } else { - logrus.SetFormatter(&logrus.TextFormatter{ - DisableSorting: true, - FullTimestamp: fullTS, - ForceColors: true, - TimestampFormat: timeFmt, - }) - logrus.Infof("go-proxy version %s", pkg.GetVersion()) - logrus.AddHook(notif.GetDispatcher()) - } - if args.Command == common.CommandReload { if err := query.ReloadServer(); err != nil { - log.Fatal(err) + E.LogFatal("server reload error", err) } - log.Print("ok") + logging.Info().Msg("ok") return } - // exit if only validate config + if args.Command == common.CommandStart { + logging.Info().Msgf("go-proxy version %s", pkg.GetVersion()) + logging.Trace().Msg("trace enabled") + // logging.AddHook(notif.GetDispatcher()) + } else { + logging.DiscardLogger() + } + if args.Command == common.CommandValidate { data, err := os.ReadFile(common.ConfigPath) if err == nil { - err = config.Validate(data).Error() + err = config.Validate(data) } if err != nil { log.Fatal("config error: ", err) @@ -88,7 +68,7 @@ func main() { var cfg *config.Config var err E.Error if cfg, err = config.Load(); err != nil { - logrus.Warn(err) + E.LogWarn("errors in config", err) } switch args.Command { @@ -145,10 +125,10 @@ func main() { autocert := config.GetAutoCertProvider() if autocert != nil { if err := autocert.Setup(); err != nil { - l.Fatal(err) + E.LogFatal("autocert setup error", err) } } else { - l.Info("autocert not configured") + logging.Info().Msg("autocert not configured") } proxyServer := server.InitProxyServer(server.Options{ @@ -174,7 +154,7 @@ func main() { <-sig // grafully shutdown - logrus.Info("shutting down") + logging.Info().Msg("shutting down") task.CancelGlobalContext() task.GlobalContextWait(time.Second * time.Duration(config.Value().TimeoutShutdown)) } @@ -182,15 +162,15 @@ func main() { func prepareDirectory(dir string) { if _, err := os.Stat(dir); os.IsNotExist(err) { if err = os.MkdirAll(dir, 0o755); err != nil { - logrus.Fatalf("failed to create directory %s: %v", dir, err) + logging.Fatal().Msgf("failed to create directory %s: %v", dir, err) } } } func printJSON(obj any) { - j, err := E.Check(json.MarshalIndent(obj, "", " ")) + j, err := json.MarshalIndent(obj, "", " ") if err != nil { - logrus.Fatal(err) + logging.Fatal().Err(err).Send() } rawLogger := log.New(os.Stdout, "", 0) rawLogger.Printf("%s", j) // raw output for convenience using "jq" diff --git a/go.mod b/go.mod index 5612f4d..3370a40 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,8 @@ require ( github.com/go-acme/lego/v4 v4.19.2 github.com/gotify/server/v2 v2.5.0 github.com/puzpuzpuz/xsync/v3 v3.4.0 + github.com/rs/zerolog v1.33.0 github.com/santhosh-tekuri/jsonschema v1.2.4 - github.com/sirupsen/logrus v1.9.3 golang.org/x/net v0.30.0 golang.org/x/text v0.19.0 gopkg.in/yaml.v3 v3.0.1 @@ -20,7 +20,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cloudflare/cloudflare-go v0.107.0 // indirect + github.com/cloudflare/cloudflare-go v0.108.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect @@ -33,6 +33,8 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.62 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect @@ -42,6 +44,7 @@ require ( github.com/ovh/go-ovh v1.6.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.30.0 // indirect diff --git a/go.sum b/go.sum index f0d259b..62f4b45 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,13 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc= -github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc= +github.com/cloudflare/cloudflare-go v0.108.0 h1:C4Skfjd8I8X3uEOGmQUT4/iGyZcWdkIU7HwvMoLkEE0= +github.com/cloudflare/cloudflare-go v0.108.0/go.mod h1:m492eNahT/9MsN7Ppnoge8AaI7QhVFtEgVm3I9HJFeU= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -40,6 +41,7 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -61,6 +63,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= @@ -88,6 +96,9 @@ github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPK github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -140,6 +151,9 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/api/handler.go b/internal/api/handler.go index 8f66864..5ee66a0 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -6,7 +6,6 @@ import ( "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -38,7 +37,6 @@ func NewHandler() http.Handler { mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("GET", "/v1/stats", v1.Stats) mux.HandleFunc("GET", "/v1/stats/ws", v1.StatsWS) - mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc()) return mux } @@ -50,7 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { host, _, _ := net.SplitHostPort(r.RemoteAddr) if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { - Logger.Warnf("blocked API request from %s", host) + LogWarn(r).Msgf("blocked API request from %s", host) http.Error(w, "forbidden", http.StatusForbidden) return } diff --git a/internal/api/v1/errorpage/http_handler.go b/internal/api/v1/errorpage/http_handler.go deleted file mode 100644 index 2da9372..0000000 --- a/internal/api/v1/errorpage/http_handler.go +++ /dev/null @@ -1,31 +0,0 @@ -package errorpage - -import ( - "net/http" - - . "github.com/yusing/go-proxy/internal/api/v1/utils" -) - -func GetHandleFunc() http.HandlerFunc { - setup() - return serveHTTP -} - -func serveHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if r.URL.Path == "/" { - http.Error(w, "invalid path", http.StatusNotFound) - return - } - content, ok := fileContentMap.Load(r.URL.Path) - if !ok { - http.Error(w, "404 not found", http.StatusNotFound) - return - } - if _, err := w.Write(content); err != nil { - HandleErr(w, r, err) - } -} diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index 084b4f1..acaeabb 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -39,15 +39,16 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { return } - var validateErr E.Error + var valErr E.Error if filename == common.ConfigFileName { - validateErr = config.Validate(content) + valErr = config.Validate(content) } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { - validateErr = provider.Validate(content) + valErr = provider.Validate(content) } + // no validation for include files - if validateErr != nil { - U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest) + if valErr != nil { + U.RespondJSON(w, r, valErr, http.StatusBadRequest) return } diff --git a/internal/api/v1/query/query.go b/internal/api/v1/query/query.go index cd30cc7..7757004 100644 --- a/internal/api/v1/query/query.go +++ b/internal/api/v1/query/query.go @@ -20,16 +20,13 @@ func ReloadServer() E.Error { } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - failure := E.Failure("server reload").Extraf("status code: %v", resp.StatusCode) - b, err := io.ReadAll(resp.Body) + failure := E.Errorf("server reload status %v", resp.StatusCode) + body, err := io.ReadAll(resp.Body) if err != nil { - return failure.Extraf("unable to read response body: %s", err) + return failure.With(err) } - reloadErr, ok := E.FromJSON(b) - if ok { - return E.Join("reload success, but server returned error", reloadErr) - } - return failure.Extraf("unable to read response body") + reloadErr := string(body) + return failure.Withf(reloadErr) } return nil } @@ -42,7 +39,7 @@ func List[T any](what string) (_ T, outErr E.Error) { } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - outErr = E.Failure("list "+what).Extraf("status code: %v", resp.StatusCode) + outErr = E.Errorf("list %s: failed, status %v", what, resp.StatusCode) return } var res T diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index ffd0609..42a4198 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -9,8 +9,8 @@ import ( func Reload(w http.ResponseWriter, r *http.Request) { if err := config.Reload(); err != nil { - U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) - } else { - w.WriteHeader(http.StatusOK) + U.HandleErr(w, r, err) + return } + U.WriteBody(w, []byte("OK")) } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index c86d325..e40c364 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -11,7 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/server" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) func Stats(w http.ResponseWriter, r *http.Request) { @@ -23,7 +23,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { originPats := make([]string, len(config.Value().MatchDomains)+len(localAddresses)) if len(originPats) == 0 { - U.Logger.Warnf("no match domains configured, accepting websocket request from all origins") + U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") originPats = []string{"*"} } else { for i, domain := range config.Value().MatchDomains { @@ -38,7 +38,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { OriginPatterns: originPats, }) if err != nil { - U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err) + U.LogError(r).Err(err).Msg("failed to upgrade websocket") return } /* trunk-ignore(golangci-lint/errcheck) */ @@ -53,7 +53,7 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { for range ticker.C { stats := getStats() if err := wsjson.Write(ctx, conn, stats); err != nil { - U.Logger.Errorf("/stats/ws failed to write JSON: %s", err) + U.LogError(r).Msg("failed to write JSON") return } } @@ -62,6 +62,6 @@ func StatsWS(w http.ResponseWriter, r *http.Request) { func getStats() map[string]any { return map[string]any{ "proxies": config.Statistics(), - "uptime": utils.FormatDuration(server.GetProxyServer().Uptime()), + "uptime": strutils.FormatDuration(server.GetProxyServer().Uptime()), } } diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 957e6cd..3c55d95 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -1,37 +1,31 @@ package utils import ( - "errors" - "fmt" "net/http" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" ) -var Logger = logrus.WithField("module", "api") - func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { if origErr == nil { return } - err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL) - Logger.Error(err) + LogError(r).Msg(origErr.Error()) if len(code) > 0 { - http.Error(w, err.String(), code[0]) + http.Error(w, origErr.Error(), code[0]) return } - http.Error(w, err.String(), http.StatusInternalServerError) + http.Error(w, origErr.Error(), http.StatusInternalServerError) } func ErrMissingKey(k string) error { - return errors.New("missing key '" + k + "' in query or request body") + return E.New("missing key '" + k + "' in query or request body") } func ErrInvalidKey(k string) error { - return errors.New("invalid key '" + k + "' in query or request body") + return E.New("invalid key '" + k + "' in query or request body") } func ErrNotFound(k, v string) error { - return fmt.Errorf("key %q with value %q not found", k, v) + return E.Errorf("key %q with value %q not found", k, v) } diff --git a/internal/api/v1/utils/http_client.go b/internal/api/v1/utils/http_client.go index 0cb4ebe..48d743b 100644 --- a/internal/api/v1/utils/http_client.go +++ b/internal/api/v1/utils/http_client.go @@ -9,12 +9,11 @@ import ( ) var ( - HTTPClient = &http.Client{ + httpClient = &http.Client{ Timeout: common.ConnectionTimeout, Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, DisableKeepAlives: true, - ForceAttemptHTTP2: true, + ForceAttemptHTTP2: false, DialContext: (&net.Dialer{ Timeout: common.DialTimeout, KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives @@ -23,7 +22,7 @@ var ( }, } - Get = HTTPClient.Get - Post = HTTPClient.Post - Head = HTTPClient.Head + Get = httpClient.Get + Post = httpClient.Post + Head = httpClient.Head ) diff --git a/internal/api/v1/utils/logging.go b/internal/api/v1/utils/logging.go new file mode 100644 index 0000000..cdf4222 --- /dev/null +++ b/internal/api/v1/utils/logging.go @@ -0,0 +1,18 @@ +package utils + +import ( + "net/http" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/logging" +) + +func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { + return logging.WithLevel(level).Str("module", "api"). + Str("method", r.Method). + Str("path", r.RequestURI) +} + +func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) } +func LogWarn(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.WarnLevel) } +func LogInfo(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.InfoLevel) } diff --git a/internal/api/v1/utils/utils.go b/internal/api/v1/utils/utils.go index c687b44..01015a2 100644 --- a/internal/api/v1/utils/utils.go +++ b/internal/api/v1/utils/utils.go @@ -3,6 +3,8 @@ package utils import ( "encoding/json" "net/http" + + "github.com/yusing/go-proxy/internal/logging" ) func WriteBody(w http.ResponseWriter, body []byte) { @@ -11,15 +13,25 @@ func WriteBody(w http.ResponseWriter, body []byte) { } } -func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool { +func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) (canProceed bool) { if len(code) > 0 { w.WriteHeader(code[0]) } w.Header().Set("Content-Type", "application/json") - j, err := json.MarshalIndent(data, "", " ") - if err != nil { - HandleErr(w, r, err) - return false + var j []byte + var err error + + switch data := data.(type) { + case string: + j = []byte(`"` + data + `"`) + case []byte: + j = data + default: + j, err = json.MarshalIndent(data, "", " ") + if err != nil { + logging.Panic().Err(err).Msg("failed to marshal json") + return false + } } _, err = w.Write(j) if err != nil { diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 351e479..670db93 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -8,12 +8,21 @@ import ( "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/lego" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/config/types" ) type Config types.AutoCertConfig +var ( + ErrMissingDomain = E.New("missing field 'domains'") + ErrMissingEmail = E.New("missing field 'email'") + ErrMissingProvider = E.New("missing field 'provider'") + ErrUnknownProvider = E.New("unknown provider") +) + func NewConfig(cfg *types.AutoCertConfig) *Config { if cfg.CertPath == "" { cfg.CertPath = CertFileDefault @@ -27,35 +36,36 @@ func NewConfig(cfg *types.AutoCertConfig) *Config { return (*Config)(cfg) } -func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { - b := E.NewBuilder("unable to initialize autocert") - defer b.To(&res) +func (cfg *Config) GetProvider() (*Provider, E.Error) { + b := E.NewBuilder("autocert errors") if cfg.Provider != ProviderLocal { if len(cfg.Domains) == 0 { - b.Addf("%s", "no domains specified") + b.Add(ErrMissingDomain) } if cfg.Provider == "" { - b.Addf("%s", "no provider specified") + b.Add(ErrMissingProvider) } if cfg.Email == "" { - b.Addf("%s", "no email specified") + b.Add(ErrMissingEmail) } // check if provider is implemented _, ok := providersGenMap[cfg.Provider] if !ok { - b.Addf("unknown provider: %q", cfg.Provider) + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + Withf(strutils.DoYouMean(utils.NearestField(cfg.Provider, providersGenMap)))) } } if b.HasError() { - return + return nil, b.Error() } - privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) - if err.HasError() { - b.Add(E.FailWith("generate private key", err)) - return + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + b.Addf("generate private key: %w", err) + return nil, b.Error() } user := &User{ @@ -66,11 +76,9 @@ func (cfg *Config) GetProvider() (provider *Provider, res E.Error) { legoCfg := lego.NewConfig(user) legoCfg.Certificate.KeyType = certcrypto.RSA2048 - provider = &Provider{ + return &Provider{ cfg: cfg, user: user, legoCfg: legoCfg, - } - - return + }, nil } diff --git a/internal/autocert/constants.go b/internal/autocert/constants.go index 19b726f..ffe0cc3 100644 --- a/internal/autocert/constants.go +++ b/internal/autocert/constants.go @@ -7,7 +7,6 @@ import ( "github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/duckdns" "github.com/go-acme/lego/v4/providers/dns/ovh" - "github.com/sirupsen/logrus" ) const ( @@ -36,5 +35,3 @@ var providersGenMap = map[string]ProviderGenerator{ var ( ErrGetCertFailure = errors.New("get certificate failed") ) - -var logger = logrus.WithField("module", "autocert") diff --git a/internal/autocert/logger.go b/internal/autocert/logger.go new file mode 100644 index 0000000..9339fa7 --- /dev/null +++ b/internal/autocert/logger.go @@ -0,0 +1,5 @@ +package autocert + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "autocert").Logger() diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 8ab2aad..4f13673 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -17,6 +17,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -57,25 +58,20 @@ func (p *Provider) GetExpiries() CertExpiries { return p.certExpiries } -func (p *Provider) ObtainCert() (res E.Error) { - b := E.NewBuilder("failed to obtain certificate") - defer b.To(&res) - +func (p *Provider) ObtainCert() E.Error { if p.cfg.Provider == ProviderLocal { return nil } if p.client == nil { - if err := p.initClient(); err.HasError() { - b.Add(E.FailWith("init autocert client", err)) - return + if err := p.initClient(); err != nil { + return err } } if p.user.Registration == nil { - if err := p.registerACME(); err.HasError() { - b.Add(E.FailWith("register ACME", err)) - return + if err := p.registerACME(); err != nil { + return E.From(err) } } @@ -84,27 +80,23 @@ func (p *Provider) ObtainCert() (res E.Error) { Domains: p.cfg.Domains, Bundle: true, } - cert, err := E.Check(client.Certificate.Obtain(req)) - if err.HasError() { - b.Add(err) - return + cert, err := client.Certificate.Obtain(req) + if err != nil { + return E.From(err) } - if err = p.saveCert(cert); err.HasError() { - b.Add(E.FailWith("save certificate", err)) - return + if err = p.saveCert(cert); err != nil { + return E.From(err) } - tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) - if err.HasError() { - b.Add(E.FailWith("parse obtained certificate", err)) - return + tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey) + if err != nil { + return E.From(err) } expiries, err := getCertExpiries(&tlsCert) - if err.HasError() { - b.Add(E.FailWith("get certificate expiry", err)) - return + if err != nil { + return E.From(err) } p.tlsCert = &tlsCert p.certExpiries = expiries @@ -113,21 +105,22 @@ func (p *Provider) ObtainCert() (res E.Error) { } func (p *Provider) LoadCert() E.Error { - cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) - if err.HasError() { - return err + cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) + if err != nil { + return E.Errorf("load SSL certificate: %w", err) } expiries, err := getCertExpiries(&cert) - if err.HasError() { - return err + if err != nil { + return E.Errorf("parse SSL certificate: %w", err) } p.tlsCert = &cert p.certExpiries = expiries - logger.Infof("next renewal in %v", U.FormatDuration(time.Until(p.ShouldRenewOn()))) + logger.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn()))) return p.renewIfNeeded() } +// ShouldRenewOn returns the time at which the certificate should be renewed. func (p *Provider) ShouldRenewOn() time.Time { for _, expiry := range p.certExpiries { return expiry.AddDate(0, -1, 0) // 1 month before @@ -150,8 +143,8 @@ func (p *Provider) ScheduleRenewal() { case <-task.Context().Done(): return case <-ticker.C: // check every 5 seconds - if err := p.renewIfNeeded(); err.HasError() { - logger.Warn(err) + if err := p.renewIfNeeded(); err != nil { + E.LogWarn("cert renew failed", err, &logger) } } } @@ -159,31 +152,32 @@ func (p *Provider) ScheduleRenewal() { } func (p *Provider) initClient() E.Error { - legoClient, err := E.Check(lego.NewClient(p.legoCfg)) - if err.HasError() { - return E.FailWith("create lego client", err) + legoClient, err := lego.NewClient(p.legoCfg) + if err != nil { + return E.From(err) } - legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) - if err.HasError() { - return E.FailWith("create lego provider", err) + generator := providersGenMap[p.cfg.Provider] + legoProvider, pErr := generator(p.cfg.Options) + if pErr != nil { + return pErr } - err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) - if err.HasError() { - return E.FailWith("set challenge provider", err) + err = legoClient.Challenge.SetDNS01Provider(legoProvider) + if err != nil { + return E.From(err) } p.client = legoClient return nil } -func (p *Provider) registerACME() E.Error { +func (p *Provider) registerACME() error { if p.user.Registration != nil { return nil } - reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) - if err.HasError() { + reg, err := p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) + if err != nil { return err } p.user.Registration = reg @@ -191,26 +185,27 @@ func (p *Provider) registerACME() E.Error { return nil } -func (p *Provider) saveCert(cert *certificate.Resource) E.Error { +func (p *Provider) saveCert(cert *certificate.Resource) error { /* This should have been done in setup but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) if err != nil { if os.IsNotExist(err) { if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { - return E.FailWith("create cert directory", err) + return err } } else { - return E.FailWith("stat cert directory", err) + return err } } err = os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- if err != nil { - return E.FailWith("write key file", err) + return err } + err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- if err != nil { - return E.FailWith("write cert file", err) + return err } return nil } @@ -232,7 +227,7 @@ func (p *Provider) certState() CertState { sort.Strings(certDomains) if !reflect.DeepEqual(certDomains, wantedDomains) { - logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) + logger.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) return CertStateMismatch } @@ -246,25 +241,25 @@ func (p *Provider) renewIfNeeded() E.Error { switch p.certState() { case CertStateExpired: - logger.Info("certs expired, renewing") + logger.Info().Msg("certs expired, renewing") case CertStateMismatch: - logger.Info("cert domains mismatch with config, renewing") + logger.Info().Msg("cert domains mismatch with config, renewing") default: return nil } - if err := p.ObtainCert(); err.HasError() { - return E.FailWith("renew certificate", err) + if err := p.ObtainCert(); err != nil { + return err } return nil } -func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.Error) { +func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { r := make(CertExpiries, len(cert.Certificate)) for _, cert := range cert.Certificate { - x509Cert, err := E.Check(x509.ParseCertificate(cert)) - if err.HasError() { - return nil, E.FailWith("parse certificate", err) + x509Cert, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err } if x509Cert.IsCA { continue @@ -284,13 +279,10 @@ func providerGenerator[CT any, PT challenge.Provider]( return func(opt types.AutocertProviderOpt) (challenge.Provider, E.Error) { cfg := defaultCfg() err := U.Deserialize(opt, cfg) - if err.HasError() { + if err != nil { return nil, err } - p, err := E.Check(newProvider(cfg)) - if err.HasError() { - return nil, err - } - return p, nil + p, pErr := newProvider(cfg) + return p, E.From(pErr) } } diff --git a/internal/autocert/provider_test/ovh_test.go b/internal/autocert/provider_test/ovh_test.go index e9d778f..35bf059 100644 --- a/internal/autocert/provider_test/ovh_test.go +++ b/internal/autocert/provider_test/ovh_test.go @@ -45,6 +45,6 @@ oauth2_config: testYaml = testYaml[1:] // remove first \n opt := make(map[string]any) ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt)) - ExpectNoError(t, U.Deserialize(opt, cfg).Error()) + ExpectNoError(t, U.Deserialize(opt, cfg)) ExpectDeepEqual(t, cfg, cfgExpected) } diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 62640b1..c65aa35 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -11,7 +11,7 @@ func (p *Provider) Setup() (err E.Error) { if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist return err } - logger.Debug("obtaining cert due to error loading cert") + logger.Debug().Msg("obtaining cert due to error loading cert") if err = p.ObtainCert(); err != nil { return err } @@ -20,7 +20,7 @@ func (p *Provider) Setup() (err E.Error) { p.ScheduleRenewal() for _, expiry := range p.GetExpiries() { - logger.Infof("certificate expire on %s", expiry) + logger.Info().Msg("certificate expire on " + expiry.String()) break } diff --git a/internal/common/args.go b/internal/common/args.go index b0dd597..35563e7 100644 --- a/internal/common/args.go +++ b/internal/common/args.go @@ -3,8 +3,7 @@ package common import ( "flag" "fmt" - - "github.com/sirupsen/logrus" + "log" ) type Args struct { @@ -44,7 +43,7 @@ func GetArgs() Args { flag.Parse() args.Command = flag.Arg(0) if err := validateArg(args.Command); err != nil { - logrus.Fatal(err) + log.Fatalf("invalid command: %s", err) } return args } @@ -55,5 +54,5 @@ func validateArg(arg string) error { return nil } } - return fmt.Errorf("invalid command: %s", arg) + return fmt.Errorf("invalid command %q", arg) } diff --git a/internal/common/env.go b/internal/common/env.go index 13d71fe..a8ad9b7 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -7,8 +7,6 @@ import ( "os" "strconv" "strings" - - "github.com/sirupsen/logrus" ) var ( @@ -40,7 +38,7 @@ func GetEnvBool(key string, defaultValue bool) bool { } b, err := strconv.ParseBool(value) if err != nil { - log.Fatalf("Invalid boolean value: %s", value) + log.Fatalf("env %s: invalid boolean value: %s", key, value) } return b } @@ -57,7 +55,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str addr = GetEnv(key, defaultValue) host, port, err := net.SplitHostPort(addr) if err != nil { - logrus.Fatalf("Invalid address: %s", addr) + log.Fatalf("env %s: invalid address: %s", key, addr) } if host == "" { host = "localhost" diff --git a/internal/config/config.go b/internal/config/config.go index 51a42cf..61e1f7e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,14 +2,15 @@ package config import ( "os" + "strconv" "sync" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config/types" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/route" proxy "github.com/yusing/go-proxy/internal/route/provider" @@ -31,7 +32,7 @@ type Config struct { var ( instance *Config cfgWatcher watcher.Watcher - logger = logrus.WithField("module", "config") + logger = logging.With().Str("module", "config").Logger() reloadMu sync.Mutex ) @@ -80,7 +81,7 @@ func WatchChanges() { configEventFlushInterval, OnConfigChange, func(err E.Error) { - logger.Error(err) + E.LogError("config reload error", err, &logger) }, ) eventQueue.Start(cfgWatcher.Events(task.Context())) @@ -93,15 +94,16 @@ func OnConfigChange(flushTask task.Task, ev []events.Event) { // just reload once and check the last event switch ev[len(ev)-1].Action { case events.ActionFileRenamed: - logger.Warn(cfgRenameWarn) + logger.Warn().Msg(cfgRenameWarn) return case events.ActionFileDeleted: - logger.Warn(cfgDeleteWarn) + logger.Warn().Msg(cfgDeleteWarn) return } if err := Reload(); err != nil { - logger.Error(err) + // recovered in event queue + panic(err) } } @@ -138,63 +140,57 @@ func (cfg *Config) Task() task.Task { } func (cfg *Config) StartProxyProviders() { - b := E.NewBuilder("errors starting providers") - cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { - b.Add(p.Start(cfg.task.Subtask(p.String()))) - }) + errs := cfg.providers.CollectErrorsParallel( + func(_ string, p *proxy.Provider) error { + subtask := cfg.task.Subtask(p.String()) + return p.Start(subtask) + }) - if b.HasError() { - logger.Error(b.Build()) + if err := E.Join(errs...); err != nil { + E.LogError("route provider errors", err, &logger) } } -func (cfg *Config) load() (res E.Error) { - errs := E.NewBuilder("errors loading config") - defer errs.To(&res) +func (cfg *Config) load() E.Error { + const errMsg = "config load error" - logger.Debug("loading config") - defer logger.Debug("loaded config") - - data, err := E.Check(os.ReadFile(common.ConfigPath)) + data, err := os.ReadFile(common.ConfigPath) if err != nil { - errs.Add(E.FailWith("read config", err)) - logrus.Fatal(errs.Build()) + E.LogFatal(errMsg, err, &logger) } if !common.NoSchemaValidation { - if err = Validate(data); err != nil { - errs.Add(E.FailWith("schema validation", err)) - logrus.Fatal(errs.Build()) + if err := Validate(data); err != nil { + E.LogFatal(errMsg, err, &logger) } } model := types.DefaultConfig() if err := E.From(yaml.Unmarshal(data, model)); err != nil { - errs.Add(E.FailWith("parse config", err)) - logrus.Fatal(errs.Build()) + E.LogFatal(errMsg, err, &logger) } // errors are non fatal below + errs := E.NewBuilder(errMsg) errs.Add(cfg.initNotification(model.Providers.Notification)) errs.Add(cfg.initAutoCert(&model.AutoCert)) errs.Add(cfg.loadRouteProviders(&model.Providers)) cfg.value = model route.SetFindMuxDomains(model.MatchDomains) - return + return errs.Error() } func (cfg *Config) initNotification(notifCfgMap types.NotificationConfigMap) (err E.Error) { if len(notifCfgMap) == 0 { return } - errs := E.NewBuilder("errors initializing notification providers") - + errs := E.NewBuilder("notification providers load errors") for name, notifCfg := range notifCfgMap { _, err := notif.RegisterProvider(cfg.task.Subtask(name), notifCfg) errs.Add(err) } - return errs.Build() + return errs.Error() } func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) { @@ -203,40 +199,45 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Error) } cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() - if err != nil { - err = E.FailWith("autocert provider", err) - } return } -func (cfg *Config) loadRouteProviders(providers *types.Providers) (outErr E.Error) { +func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { subtask := cfg.task.Subtask("load route providers") defer subtask.Finish("done") - errs := E.NewBuilder("errors loading route providers") - results := E.NewBuilder("loaded providers") - defer errs.To(&outErr) + errs := E.NewBuilder("route provider errors") + results := E.NewBuilder("loaded route providers") + lenLongestName := 0 for _, filename := range providers.Files { p, err := proxy.NewFileProvider(filename) if err != nil { - errs.Add(err) + errs.Add(E.PrependSubject(filename, err)) continue } cfg.providers.Store(p.GetName(), p) - errs.Add(p.LoadRoutes().Subject(filename)) - results.Addf("%d routes from %s", p.NumRoutes(), p.String()) + if len(p.GetName()) > lenLongestName { + lenLongestName = len(p.GetName()) + } } for name, dockerHost := range providers.Docker { p, err := proxy.NewDockerProvider(name, dockerHost) if err != nil { - errs.Add(err.Subjectf("%s (%s)", name, dockerHost)) + errs.Add(E.PrependSubject(name, err)) continue } cfg.providers.Store(p.GetName(), p) - errs.Add(p.LoadRoutes().Subject(p.GetName())) - results.Addf("%d routes from %s", p.NumRoutes(), p.String()) + if len(p.GetName()) > lenLongestName { + lenLongestName = len(p.GetName()) + } } - logger.Info(results.Build()) - return + cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { + if err := p.LoadRoutes(); err != nil { + errs.Add(err.Subject(p.String())) + } + results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.GetName(), p.NumRoutes()) + }) + logger.Info().Msg(results.String()) + return errs.Error() } diff --git a/internal/config/query.go b/internal/config/query.go index 813b8c9..8933cf3 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -9,8 +9,8 @@ import ( "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/route" proxy "github.com/yusing/go-proxy/internal/route/provider" - U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" ) func DumpEntries() map[string]*entry.RawEntry { @@ -61,7 +61,7 @@ func HomepageConfig() homepage.Config { } if item.Name == "" { - item.Name = U.Title( + item.Name = strutils.Title( strings.ReplaceAll( strings.ReplaceAll(alias, "-", " "), "_", " ", diff --git a/internal/docker/client.go b/internal/docker/client.go index d4cb5f4..f1ecf5b 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -1,14 +1,15 @@ package docker import ( + "errors" "net/http" "sync" "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -22,7 +23,7 @@ type ( key string refCount *U.RefCount - l logrus.FieldLogger + l zerolog.Logger } ) @@ -70,7 +71,7 @@ func (c *SharedClient) Close() error { // Returns: // - Client: the Docker client connection. // - error: an error if the connection failed. -func ConnectClient(host string) (Client, E.Error) { +func ConnectClient(host string) (Client, error) { clientMapMu.Lock() defer clientMapMu.Unlock() @@ -85,13 +86,13 @@ func ConnectClient(host string) (Client, E.Error) { switch host { case "": - return nil, E.Invalid("docker host", "empty") + return nil, errors.New("empty docker host") case common.DockerHostFromEnv: opt = clientOptEnvHost default: - helper, err := E.Check(connhelper.GetConnectionHelper(host)) - if err.HasError() { - return nil, E.UnexpectedError(err.Error()) + helper, err := connhelper.GetConnectionHelper(host) + if err != nil { + logging.Panic().Err(err).Msg("failed to get connection helper") } if helper != nil { httpClient := &http.Client{ @@ -113,9 +114,9 @@ func ConnectClient(host string) (Client, E.Error) { } } - client, err := E.Check(client.NewClientWithOpts(opt...)) + client, err := client.NewClientWithOpts(opt...) - if err.HasError() { + if err != nil { return nil, err } @@ -123,9 +124,9 @@ func ConnectClient(host string) (Client, E.Error) { Client: client, key: host, refCount: U.NewRefCounter(), - l: logger.WithField("docker_client", client.DaemonHost()), + l: logger.With().Str("address", client.DaemonHost()).Logger(), } - c.l.Debugf("client connected") + c.l.Trace().Msg("client connected") clientMap.Store(host, c) @@ -135,7 +136,7 @@ func ConnectClient(host string) (Client, E.Error) { if c.Connected() { c.Client.Close() - c.l.Debugf("client closed") + c.l.Trace().Msg("client closed") } }() return c, nil diff --git a/internal/docker/container.go b/internal/docker/container.go index d0afa9f..f87535f 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/docker/docker/api/types" - "github.com/sirupsen/logrus" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( @@ -59,7 +59,7 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { NetworkMode: c.HostConfig.NetworkMode, Aliases: helper.getAliases(), - IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)), + IsExcluded: strutils.ParseBool(helper.getDeleteLabel(LabelExclude)), IsExplicit: isExplicit, IsDatabase: helper.isDatabase(), IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), @@ -120,7 +120,7 @@ func (c *Container) setPublicIP() { } url, err := url.Parse(c.DockerHost) if err != nil { - logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err) + logger.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost) c.PublicIP = "127.0.0.1" return } diff --git a/internal/docker/container_helper.go b/internal/docker/container_helper.go index 12d6cc7..d3f3a63 100644 --- a/internal/docker/container_helper.go +++ b/internal/docker/container_helper.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/docker/docker/api/types" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type containerHelper struct { @@ -23,7 +23,7 @@ func (c containerHelper) getDeleteLabel(label string) string { func (c containerHelper) getAliases() []string { if l := c.getDeleteLabel(LabelAliases); l != "" { - return U.CommaSeperatedList(l) + return strutils.CommaSeperatedList(l) } return []string{c.getName()} } @@ -44,7 +44,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping { if v.PublicPort == 0 { continue } - res[U.PortString(v.PublicPort)] = v + res[strutils.PortString(v.PublicPort)] = v } return res } @@ -52,7 +52,7 @@ func (c containerHelper) getPublicPortMapping() PortMapping { func (c containerHelper) getPrivatePortMapping() PortMapping { res := make(PortMapping) for _, v := range c.Ports { - res[U.PortString(v.PrivatePort)] = v + res[strutils.PortString(v.PrivatePort)] = v } return res } diff --git a/internal/docker/idlewatcher/types/config.go b/internal/docker/idlewatcher/types/config.go index a12829f..95201c7 100644 --- a/internal/docker/idlewatcher/types/config.go +++ b/internal/docker/idlewatcher/types/config.go @@ -1,6 +1,7 @@ package types import ( + "errors" "time" "github.com/yusing/go-proxy/internal/docker" @@ -30,7 +31,7 @@ const ( StopMethodKill StopMethod = "kill" ) -func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { +func ValidateConfig(cont *docker.Container) (*Config, E.Error) { if cont == nil { return nil, nil } @@ -44,26 +45,16 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { }, nil } - b := E.NewBuilder("invalid idlewatcher config") - defer b.To(&res) + errs := E.NewBuilder("invalid idlewatcher config") - idleTimeout, err := validateDurationPostitive(cont.IdleTimeout) - b.Add(err.Subjectf("%s", "idle_timeout")) + idleTimeout := E.Collect(errs, validateDurationPostitive, cont.IdleTimeout) + wakeTimeout := E.Collect(errs, validateDurationPostitive, cont.WakeTimeout) + stopTimeout := E.Collect(errs, validateDurationPostitive, cont.StopTimeout) + stopMethod := E.Collect(errs, validateStopMethod, cont.StopMethod) + signal := E.Collect(errs, validateSignal, cont.StopSignal) - wakeTimeout, err := validateDurationPostitive(cont.WakeTimeout) - b.Add(err.Subjectf("%s", "wake_timeout")) - - stopTimeout, err := validateDurationPostitive(cont.StopTimeout) - b.Add(err.Subjectf("%s", "stop_timeout")) - - stopMethod, err := validateStopMethod(cont.StopMethod) - b.Add(err) - - signal, err := validateSignal(cont.StopSignal) - b.Add(err) - - if err := b.Build(); err != nil { - return + if errs.HasError() { + return nil, errs.Error() } return &Config{ @@ -80,33 +71,33 @@ func ValidateConfig(cont *docker.Container) (cfg *Config, res E.Error) { }, nil } -func validateDurationPostitive(value string) (time.Duration, E.Error) { +func validateDurationPostitive(value string) (time.Duration, error) { d, err := time.ParseDuration(value) if err != nil { - return 0, E.Invalid("duration", value).With(err) + return 0, err } if d < 0 { - return 0, E.Invalid("duration", "negative value") + return 0, errors.New("duration must be positive") } return d, nil } -func validateSignal(s string) (Signal, E.Error) { +func validateSignal(s string) (Signal, error) { switch s { case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", "INT", "TERM", "HUP", "QUIT": return Signal(s), nil } - return "", E.Invalid("signal", s) + return "", errors.New("invalid signal " + s) } -func validateStopMethod(s string) (StopMethod, E.Error) { +func validateStopMethod(s string) (StopMethod, error) { sm := StopMethod(s) switch sm { case StopMethodPause, StopMethodStop, StopMethodKill: return sm, nil default: - return "", E.Invalid("stop_method", sm) + return "", errors.New("invalid stop method " + s) } } diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 23630a6..6fda047 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -42,7 +42,7 @@ func newWaker(providerSubTask task.Task, entry entry.Entry, rp *gphttp.ReversePr watcher, err := registerWatcher(providerSubTask, entry, waker) if err != nil { - return nil, err + return nil, E.Errorf("register watcher: %w", err) } if rp != nil { @@ -75,6 +75,9 @@ func (w *Watcher) Start(routeSubTask task.Task) E.Error { // Finish implements health.HealthMonitor. func (w *Watcher) Finish(reason any) { + if w.stream != nil { + w.stream.Close() + } } // Name implements health.HealthMonitor. diff --git a/internal/docker/idlewatcher/waker_http.go b/internal/docker/idlewatcher/waker_http.go index b7cf6c2..2634e5f 100644 --- a/internal/docker/idlewatcher/waker_http.go +++ b/internal/docker/idlewatcher/waker_http.go @@ -7,9 +7,7 @@ import ( "strconv" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -20,21 +18,26 @@ func (w *Watcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if !shouldNext { return } - w.rp.ServeHTTP(rw, r) + select { + case <-r.Context().Done(): + return + default: + w.rp.ServeHTTP(rw, r) + } } func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldNext bool) { w.resetIdleTimer() - if r.Body != nil { - defer r.Body.Close() - } - // pass through if container is already ready if w.ready.Load() { return true } + if r.Body != nil { + defer r.Body.Close() + } + accept := gphttp.GetAccept(r.Header) acceptHTML := (r.Method == http.MethodGet && accept.AcceptHTML() || r.RequestURI == "/" && accept.IsEmpty()) @@ -49,23 +52,22 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN rw.Header().Add("Cache-Control", "must-revalidate") rw.Header().Add("Connection", "close") if _, err := rw.Write(body); err != nil { - w.l.Errorf("error writing http response: %s", err) + w.Err(err).Msg("error writing http response") } - return + return false } ctx, cancel := context.WithTimeoutCause(r.Context(), w.WakeTimeout, errors.New("wake timeout")) defer cancel() - checkCanceled := func() bool { + checkCanceled := func() (canceled bool) { select { - case <-w.task.Context().Done(): - w.l.Debugf("wake canceled: %s", context.Cause(w.task.Context())) - http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) - return true case <-ctx.Done(): - w.l.Debugf("wake canceled: %s", context.Cause(ctx)) - http.Error(rw, "Waking timed out", http.StatusGatewayTimeout) + w.WakeDebug().Str("cause", context.Cause(ctx).Error()).Msg("canceled") + return true + case <-w.task.Context().Done(): + w.WakeDebug().Str("cause", w.task.FinishCause().Error()).Msg("canceled") + http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) return true default: return false @@ -76,12 +78,12 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN return false } - w.l.Debug("wake signal received") + w.WakeTrace().Msg("signal received") err := w.wakeIfStopped() if err != nil { - w.l.Error(E.FailWith("wake", err)) + w.WakeError(err).Send() http.Error(rw, "Error waking container", http.StatusInternalServerError) - return + return false } for { @@ -92,11 +94,11 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN if w.Status() == health.StatusHealthy { w.resetIdleTimer() if isCheckRedirect { - logrus.Debugf("container %s is ready, redirecting to %s ...", w.String(), w.hc.URL()) + w.Debug().Msgf("redirecting to %s ...", w.hc.URL()) rw.WriteHeader(http.StatusOK) - return + return false } - logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + w.Debug().Msgf("passing through to %s ...", w.hc.URL()) return true } diff --git a/internal/docker/idlewatcher/waker_stream.go b/internal/docker/idlewatcher/waker_stream.go index 1ec4174..5cd648b 100644 --- a/internal/docker/idlewatcher/waker_stream.go +++ b/internal/docker/idlewatcher/waker_stream.go @@ -7,7 +7,7 @@ import ( "net" "time" - "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -16,25 +16,25 @@ func (w *Watcher) Addr() net.Addr { return w.stream.Addr() } +// Setup implements types.Stream. func (w *Watcher) Setup() error { return w.stream.Setup() } // Accept implements types.Stream. -func (w *Watcher) Accept() (conn net.Conn, err error) { +func (w *Watcher) Accept() (conn types.StreamConn, err error) { conn, err = w.stream.Accept() if err != nil { - logrus.Errorf("accept failed with error: %s", err) return } - if err := w.wakeFromStream(); err != nil { - w.l.Error(err) + if wakeErr := w.wakeFromStream(); wakeErr != nil { + w.WakeError(wakeErr).Msg("error waking from stream") } return } // Handle implements types.Stream. -func (w *Watcher) Handle(conn net.Conn) error { +func (w *Watcher) Handle(conn types.StreamConn) error { if err := w.wakeFromStream(); err != nil { return err } @@ -54,11 +54,11 @@ func (w *Watcher) wakeFromStream() error { return nil } - w.l.Debug("wake signal received") + w.WakeDebug().Msg("wake signal received") wakeErr := w.wakeIfStopped() if wakeErr != nil { - wakeErr = fmt.Errorf("wake failed with error: %w", wakeErr) - w.l.Error(wakeErr) + wakeErr = fmt.Errorf("%s failed: %w", w.String(), wakeErr) + w.WakeError(wakeErr).Msg("wake failed") return wakeErr } @@ -69,18 +69,18 @@ func (w *Watcher) wakeFromStream() error { select { case <-w.task.Context().Done(): cause := w.task.FinishCause() - w.l.Debugf("wake canceled: %s", cause) + w.WakeDebug().Str("cause", cause.Error()).Msg("canceled") return cause case <-ctx.Done(): cause := context.Cause(ctx) - w.l.Debugf("wake canceled: %s", cause) + w.WakeDebug().Str("cause", cause.Error()).Msg("timeout") return cause default: } if w.Status() == health.StatusHealthy { w.resetIdleTimer() - logrus.Infof("container %s is ready, passing through to %s", w.String(), w.hc.URL()) + w.Debug().Msg("container is ready, passing through to " + w.hc.URL().String()) return nil } diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index e5388fc..d23502d 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -3,15 +3,15 @@ package idlewatcher import ( "context" "errors" - "fmt" "sync" "time" "github.com/docker/docker/api/types/container" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" D "github.com/yusing/go-proxy/internal/docker" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/proxy/entry" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" @@ -25,6 +25,8 @@ type ( Watcher struct { _ U.NoCopy + zerolog.Logger + *idlewatcher.Config *waker @@ -32,7 +34,6 @@ type ( stopByMethod StopCallback // send a docker command w.r.t. `stop_method` ticker *time.Ticker task task.Task - l *logrus.Entry } WakeDone <-chan error @@ -44,13 +45,12 @@ var ( watcherMap = F.NewMapOf[string, *Watcher]() watcherMapMu sync.Mutex - logger = logrus.WithField("module", "idle_watcher") + logger = logging.With().Str("module", "idle_watcher").Logger() ) const dockerReqTimeout = 3 * time.Second -func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, E.Error) { - failure := E.Failure("idle_watcher register") +func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) (*Watcher, error) { cfg := entry.IdlewatcherConfig() if cfg.IdleTimeout == 0 { @@ -71,17 +71,17 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) } client, err := D.ConnectClient(cfg.DockerHost) - if err.HasError() { - return nil, failure.With(err) + if err != nil { + return nil, err } w := &Watcher{ + Logger: logger.With().Str("name", cfg.ContainerName).Logger(), Config: cfg, waker: waker, client: client, task: providerSubtask, ticker: time.NewTicker(cfg.IdleTimeout), - l: logger.WithField("container", cfg.ContainerName), } w.stopByMethod = w.getStopCallback() watcherMap.Store(key, w) @@ -99,6 +99,23 @@ func registerWatcher(providerSubtask task.Task, entry entry.Entry, waker *waker) return w, nil } +// WakeDebug logs a debug message related to waking the container. +func (w *Watcher) WakeDebug() *zerolog.Event { + return w.Debug().Str("action", "wake") +} + +func (w *Watcher) WakeTrace() *zerolog.Event { + return w.Trace().Str("action", "wake") +} + +func (w *Watcher) WakeError(err error) *zerolog.Event { + return w.Err(err).Str("action", "wake") +} + +func (w *Watcher) LogReason(action, reason string) { + w.Info().Str("reason", reason).Msg(action) +} + func (w *Watcher) containerStop(ctx context.Context) error { return w.client.ContainerStop(ctx, w.ContainerID, container.StopOptions{ Signal: string(w.StopSignal), @@ -130,7 +147,7 @@ func (w *Watcher) containerStatus() (string, error) { defer cancel() json, err := w.client.ContainerInspect(ctx, w.ContainerID) if err != nil { - return "", fmt.Errorf("failed to inspect container: %w", err) + return "", err } return json.State.Status, nil } @@ -181,7 +198,7 @@ func (w *Watcher) getStopCallback() StopCallback { } func (w *Watcher) resetIdleTimer() { - w.l.Trace("reset idle timer") + w.Trace().Msg("reset idle timer") w.ticker.Reset(w.IdleTimeout) } @@ -190,7 +207,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, - W.DockerrFilterContainer(w.ContainerID), + W.DockerFilterContainerNameID(w.ContainerID), W.DockerFilterStart, W.DockerFilterStop, W.DockerFilterDie, @@ -214,7 +231,7 @@ func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask tas // // it exits only if the context is canceled, the container is destroyed, // errors occured on docker client, or route provider died (mainly caused by config reload). -func (w *Watcher) watchUntilDestroy() error { +func (w *Watcher) watchUntilDestroy() (returnCause error) { dockerWatcher := W.NewDockerWatcherWithClient(w.client) eventTask, dockerEventCh, dockerEventErrCh := w.getEventCh(dockerWatcher) defer eventTask.Finish("stopped") @@ -224,36 +241,36 @@ func (w *Watcher) watchUntilDestroy() error { case <-w.task.Context().Done(): return w.task.FinishCause() case err := <-dockerEventErrCh: - if err != nil && err.IsNot(context.Canceled) { - w.l.Error(E.FailWith("docker watcher", err)) - return err.Error() + if !err.Is(context.Canceled) { + E.LogError("idlewatcher error", err, &w.Logger) } + return err case e := <-dockerEventCh: switch { case e.Action == events.ActionContainerDestroy: w.ContainerRunning = false w.ready.Store(false) - w.l.Info("watcher stopped by container destruction") + w.LogReason("watcher stopped", "container destroyed") return errors.New("container destroyed") // create / start / unpause case e.Action.IsContainerWake(): w.ContainerRunning = true w.resetIdleTimer() - w.l.Info("container awaken") + w.Info().Msg("awaken") case e.Action.IsContainerSleep(): // stop / pause / kil w.ContainerRunning = false w.ready.Store(false) w.ticker.Stop() default: - w.l.Errorf("unexpected docker event: %s", e) + w.Error().Msg("unexpected docker event: " + e.String()) } // container name changed should also change the container id if w.ContainerName != e.ActorName { - w.l.Debugf("container renamed %s -> %s", w.ContainerName, e.ActorName) + w.Debug().Msgf("renamed %s -> %s", w.ContainerName, e.ActorName) w.ContainerName = e.ActorName } if w.ContainerID != e.ActorID { - w.l.Debugf("container id changed %s -> %s", w.ContainerID, e.ActorID) + w.Debug().Msgf("id changed %s -> %s", w.ContainerID, e.ActorID) w.ContainerID = e.ActorID // recreate event stream eventTask.Finish("recreate event stream") @@ -263,9 +280,9 @@ func (w *Watcher) watchUntilDestroy() error { w.ticker.Stop() if w.ContainerRunning { if err := w.stopByMethod(); err != nil && !errors.Is(err, context.Canceled) { - w.l.Errorf("container stop with method %q failed with error: %v", w.StopMethod, err) + w.Err(err).Msgf("container stop with method %q failed", w.StopMethod) } else { - w.l.Info("container stopped by idle timeout") + w.LogReason("container stopped", "idle timeout") } } } diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index 7220dd8..3531af5 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -4,28 +4,26 @@ import ( "context" "errors" "time" - - E "github.com/yusing/go-proxy/internal/error" ) -func Inspect(dockerHost string, containerID string) (*Container, E.Error) { +func Inspect(dockerHost string, containerID string) (*Container, error) { client, err := ConnectClient(dockerHost) defer client.Close() - if err.HasError() { - return nil, E.FailWith("connect to docker", err) + if err != nil { + return nil, err } return client.Inspect(containerID) } -func (c Client) Inspect(containerID string) (*Container, E.Error) { +func (c Client) Inspect(containerID string) (*Container, error) { ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("docker container inspect timeout")) defer cancel() json, err := c.ContainerInspect(ctx, containerID) if err != nil { - return nil, E.From(err) + return nil, err } return FromJSON(json, c.key), nil } diff --git a/internal/docker/label.go b/internal/docker/label.go index 0d3e8e4..0d13cd5 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -24,6 +24,11 @@ type ( NestedLabelMap map[string]U.SerializedObject ) +var ( + ErrApplyToNil = E.New("label value is nil") + ErrFieldNotExist = E.New("field does not exist") +) + func (l *Label) String() string { if l.Attribute == "" { return l.Namespace + "." + l.Target @@ -41,7 +46,7 @@ func (l *Label) String() string { // - error: an error if the field does not exist. func ApplyLabel[T any](obj *T, l *Label) E.Error { if obj == nil { - return E.Invalid("nil object", l) + return ErrApplyToNil.Subject(l.String()) } switch nestedLabel := l.Value.(type) { case *Label: @@ -54,7 +59,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { } } if !field.IsValid() { - return E.NotExist("field", l.Attribute) + return ErrFieldNotExist.Subject(l.Attribute).Subject(l.String()) } dst, ok := field.Interface().(NestedLabelMap) if !ok { @@ -65,7 +70,11 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { } else { field = field.Addr() } - return U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface()) + err := U.Deserialize(U.SerializedObject{nestedLabel.Namespace: nestedLabel.Value}, field.Interface()) + if err != nil { + return err.Subject(l.String()) + } + return nil } if dst == nil { field.Set(reflect.MakeMap(reflect.TypeFor[NestedLabelMap]())) @@ -77,18 +86,22 @@ func ApplyLabel[T any](obj *T, l *Label) E.Error { dst[nestedLabel.Namespace][nestedLabel.Attribute] = nestedLabel.Value return nil default: - return U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) + err := U.Deserialize(U.SerializedObject{l.Attribute: l.Value}, obj) + if err != nil { + return err.Subject(l.String()) + } + return nil } } -func ParseLabel(label string, value string) (*Label, E.Error) { +func ParseLabel(label string, value string) *Label { parts := strings.Split(label, ".") if len(parts) < 2 { return &Label{ Namespace: label, Value: value, - }, nil + } } l := &Label{ @@ -104,12 +117,9 @@ func ParseLabel(label string, value string) (*Label, E.Error) { l.Attribute = parts[2] default: l.Attribute = parts[2] - nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value) - if err != nil { - return nil, err - } + nestedLabel := ParseLabel(strings.Join(parts[3:], "."), value) l.Value = nestedLabel } - return l, nil + return l } diff --git a/internal/docker/label_test.go b/internal/docker/label_test.go index 3f178ea..4591774 100644 --- a/internal/docker/label_test.go +++ b/internal/docker/label_test.go @@ -20,9 +20,8 @@ func makeLabel(ns, name, attr string) string { func TestNestedLabel(t *testing.T) { mAttr := "prop1" - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - sGot := ExpectType[*Label](t, pl.Value) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + sGot := ExpectType[*Label](t, lbl.Value) ExpectFalse(t, sGot == nil) ExpectEqual(t, sGot.Namespace, mName) ExpectEqual(t, sGot.Attribute, mAttr) @@ -32,10 +31,9 @@ func TestApplyNestedLabel(t *testing.T) { entry := new(struct { Middlewares NestedLabelMap `yaml:"middlewares"` }) - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) middleware1, ok := entry.Middlewares[mName] ExpectTrue(t, ok) got := ExpectType[string](t, middleware1[mAttr]) @@ -52,10 +50,9 @@ func TestApplyNestedLabelExisting(t *testing.T) { entry.Middlewares[mName] = make(U.SerializedObject) entry.Middlewares[mName][checkAttr] = checkV - pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) middleware1, ok := entry.Middlewares[mName] ExpectTrue(t, ok) got := ExpectType[string](t, middleware1[mAttr]) @@ -74,10 +71,9 @@ func TestApplyNestedLabelNoAttr(t *testing.T) { entry.Middlewares = make(NestedLabelMap) entry.Middlewares[mName] = make(U.SerializedObject) - pl, err := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v) - ExpectNoError(t, err.Error()) - err = ApplyLabel(entry, pl) - ExpectNoError(t, err.Error()) + lbl := ParseLabel(makeLabel(NSProxy, "foo", fmt.Sprintf("%s.%s", "middlewares", mName)), v) + err := ApplyLabel(entry, lbl) + ExpectNoError(t, err) _, ok := entry.Middlewares[mName] ExpectTrue(t, ok) } diff --git a/internal/docker/list_containers.go b/internal/docker/list_containers.go index 285de52..ba9b96e 100644 --- a/internal/docker/list_containers.go +++ b/internal/docker/list_containers.go @@ -8,7 +8,6 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" - E "github.com/yusing/go-proxy/internal/error" ) var listOptions = container.ListOptions{ @@ -23,19 +22,19 @@ var listOptions = container.ListOptions{ All: true, } -func ListContainers(clientHost string) ([]types.Container, E.Error) { +func ListContainers(clientHost string) ([]types.Container, error) { dockerClient, err := ConnectClient(clientHost) - if err.HasError() { - return nil, E.FailWith("connect to docker", err) + if err != nil { + return nil, err } defer dockerClient.Close() ctx, cancel := context.WithTimeoutCause(context.Background(), 3*time.Second, errors.New("list containers timeout")) defer cancel() - containers, err := E.Check(dockerClient.ContainerList(ctx, listOptions)) - if err.HasError() { - return nil, E.FailWith("list containers", err) + containers, err := dockerClient.ContainerList(ctx, listOptions) + if err != nil { + return nil, err } return containers, nil } diff --git a/internal/docker/logger.go b/internal/docker/logger.go index b86d1be..d27f2cb 100644 --- a/internal/docker/logger.go +++ b/internal/docker/logger.go @@ -1,5 +1,7 @@ package docker -import "github.com/sirupsen/logrus" +import ( + "github.com/yusing/go-proxy/internal/logging" +) -var logger = logrus.WithField("module", "docker") +var logger = logging.With().Str("module", "docker").Logger() diff --git a/internal/error/base.go b/internal/error/base.go new file mode 100644 index 0000000..d1287ab --- /dev/null +++ b/internal/error/base.go @@ -0,0 +1,46 @@ +package error + +import ( + "errors" + "fmt" +) + +// baseError is an immutable wrapper around an error. +type baseError struct { + Err error `json:"err"` +} + +func (err *baseError) Unwrap() error { + return err.Err +} + +func (err *baseError) Is(other error) bool { + if other, ok := other.(*baseError); ok { + return errors.Is(err.Err, other.Err) + } + return errors.Is(err.Err, other) +} + +func (err baseError) Subject(subject string) Error { + err.Err = PrependSubject(subject, err.Err) + return &err +} + +func (err *baseError) Subjectf(format string, args ...any) Error { + if len(args) > 0 { + return err.Subject(fmt.Sprintf(format, args...)) + } + return err.Subject(format) +} + +func (err baseError) With(extra error) Error { + return &nestedError{&err, []error{extra}} +} + +func (err baseError) Withf(format string, args ...any) Error { + return &nestedError{&err, []error{fmt.Errorf(format, args...)}} +} + +func (err *baseError) Error() string { + return err.Err.Error() +} diff --git a/internal/error/builder.go b/internal/error/builder.go index cdab20a..12aa474 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -1,104 +1,104 @@ package error import ( - "errors" "fmt" "sync" ) type Builder struct { - *builder -} - -type builder struct { - message string - errors []Error + about string + errs []error sync.Mutex } -func NewBuilder(format string, args ...any) Builder { - if len(args) > 0 { - return Builder{&builder{message: fmt.Sprintf(format, args...)}} +func NewBuilder(about string) *Builder { + return &Builder{about: about} +} + +func (b *Builder) About() string { + if !b.HasError() { + return "" } - return Builder{&builder{message: format}} + return b.about +} + +//go:inline +func (b *Builder) HasError() bool { + return len(b.errs) > 0 +} + +func (b *Builder) Error() Error { + if !b.HasError() { + return nil + } + if len(b.errs) == 1 { + return From(b.errs[0]) + } + return &nestedError{Err: New(b.about), Extras: b.errs} +} + +func (b *Builder) String() string { + if !b.HasError() { + return "" + } + return (&nestedError{Err: New(b.about), Extras: b.errs}).Error() } // Add adds an error to the Builder. // // adding nil is no-op, -// -// flatten is a boolean flag to flatten the NestedError. -func (b Builder) Add(err Error, flatten ...bool) { - if err != nil { - b.Lock() - if len(flatten) > 0 && flatten[0] { - for _, e := range err.extras { - b.errors = append(b.errors, &e) - } +func (b *Builder) Add(err error) *Builder { + if err == nil { + return b + } + + b.Lock() + defer b.Unlock() + + switch err := err.(type) { + case *baseError: + b.errs = append(b.errs, err.Err) + case *nestedError: + if err.Err == nil { + b.errs = append(b.errs, err.Extras...) } else { - b.errors = append(b.errors, err) + b.errs = append(b.errs, err) } - b.Unlock() - } -} - -func (b Builder) AddE(err error) { - b.Add(From(err)) -} - -func (b Builder) Addf(format string, args ...any) { - if len(args) > 0 { - b.Add(errorf(format, args...)) - } else { - b.AddE(errors.New(format)) - } -} - -func (b Builder) AddRange(errs ...Error) { - b.Lock() - defer b.Unlock() - for _, err := range errs { - b.errors = append(b.errors, err) - } -} - -func (b Builder) AddRangeE(errs ...error) { - b.Lock() - defer b.Unlock() - for _, err := range errs { - b.errors = append(b.errors, From(err)) - } -} - -// Build builds a NestedError based on the errors collected in the Builder. -// -// If there are no errors in the Builder, it returns a Nil() NestedError. -// Otherwise, it returns a NestedError with the message and the errors collected. -// -// Returns: -// - NestedError: the built NestedError. -func (b Builder) Build() Error { - if len(b.errors) == 0 { - return nil - } - return Join(b.message, b.errors...) -} - -func (b Builder) To(ptr *Error) { - switch { - case ptr == nil: - return - case *ptr == nil: - *ptr = b.Build() default: - (*ptr).extras = append((*ptr).extras, *b.Build()) + b.errs = append(b.errs, err) } + + return b } -func (b Builder) String() string { - return b.Build().String() +func (b *Builder) Adds(err string) *Builder { + b.Lock() + defer b.Unlock() + b.errs = append(b.errs, newError(err)) + return b } -func (b Builder) HasError() bool { - return len(b.errors) > 0 +func (b *Builder) Addf(format string, args ...any) *Builder { + if len(args) > 0 { + b.Lock() + defer b.Unlock() + b.errs = append(b.errs, fmt.Errorf(format, args...)) + } else { + b.Adds(format) + } + + return b +} + +func (b *Builder) AddRange(errs ...error) *Builder { + b.Lock() + defer b.Unlock() + + for _, err := range errs { + if err != nil { + b.errs = append(b.errs, err) + } + } + + return b } diff --git a/internal/error/builder_test.go b/internal/error/builder_test.go index 2c24948..ec22615 100644 --- a/internal/error/builder_test.go +++ b/internal/error/builder_test.go @@ -1,6 +1,9 @@ package error_test import ( + "context" + "errors" + "io" "testing" . "github.com/yusing/go-proxy/internal/error" @@ -8,14 +11,13 @@ import ( ) func TestBuilderEmpty(t *testing.T) { - eb := NewBuilder("qwer") - ExpectTrue(t, eb.Build() == nil) - ExpectTrue(t, eb.Build().NoError()) + eb := NewBuilder("foo") + ExpectTrue(t, errors.Is(eb.Error(), nil)) ExpectFalse(t, eb.HasError()) } func TestBuilderAddNil(t *testing.T) { - eb := NewBuilder("asdf") + eb := NewBuilder("foo") var err Error for range 3 { eb.Add(nil) @@ -23,41 +25,31 @@ func TestBuilderAddNil(t *testing.T) { for range 3 { eb.Add(err) } - ExpectTrue(t, eb.Build() == nil) - ExpectTrue(t, eb.Build().NoError()) + eb.AddRange(nil, nil, err) ExpectFalse(t, eb.HasError()) + ExpectTrue(t, eb.Error() == nil) +} + +func TestBuilderIs(t *testing.T) { + eb := NewBuilder("foo") + eb.Add(context.Canceled) + eb.Add(io.ErrShortBuffer) + ExpectTrue(t, eb.HasError()) + ExpectError(t, io.ErrShortBuffer, eb.Error()) + ExpectError(t, context.Canceled, eb.Error()) } func TestBuilderNested(t *testing.T) { - eb := NewBuilder("error occurred") - eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2"))) - eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) - - got := eb.Build().String() - expected1 := (`error occurred: - - Action 1 failed: - - invalid Inner: 1 - - invalid Inner: 2 - - Action 2 failed: - - invalid Inner: 3`) - expected2 := (`error occurred: - - Action 1 failed: - - invalid Inner: "1" - - invalid Inner: "2" - - Action 2 failed: - - invalid Inner: "3"`) - ExpectEqualAny(t, got, []string{expected1, expected2}) -} - -func TestBuilderTo(t *testing.T) { - eb := NewBuilder("error occurred") - eb.Addf("abcd") - - var err Error - eb.To(&err) - got := err.String() - expected := (`error occurred: - - abcd`) + eb := NewBuilder("action failed") + eb.Add(New("Action 1").Withf("Inner: 1").Withf("Inner: 2")) + eb.Add(New("Action 2").Withf("Inner: 3")) + got := eb.String() + expected := `action failed + • Action 1 + • Inner: 1 + • Inner: 2 + • Action 2 + • Inner: 3` ExpectEqual(t, got, expected) } diff --git a/internal/error/error.go b/internal/error/error.go index 39e7d60..e584653 100644 --- a/internal/error/error.go +++ b/internal/error/error.go @@ -1,317 +1,31 @@ package error -import ( - "encoding/json" - "errors" - "fmt" - "strings" -) +type Error interface { + error -type ( - Error = *ErrorImpl - ErrorImpl struct { - subject string - err error - extras []ErrorImpl - } - ErrorJSONMarshaller struct { - Subject string `json:"subject"` - Err string `json:"error"` - Extras []ErrorJSONMarshaller `json:"extras,omitempty"` - } -) - -func From(err error) Error { - if IsNil(err) { - return nil - } - return &ErrorImpl{err: err} + // Is is a wrapper for errors.Is when there is no sub-error. + // + // When there are sub-errors, they will also be checked. + Is(other error) bool + // With appends a sub-error to the error. + With(extra error) Error + // Withf is a wrapper for With(fmt.Errorf(format, args...)). + Withf(format string, args ...any) Error + // Subject prepends the given subject with a colon and space to the error message. + // + // If there is already a subject in the error message, the subject will be + // prepended to the existing subject with " > ". + // + // Subject empty string is ignored. + Subject(subject string) Error + // Subjectf is a wrapper for Subject(fmt.Sprintf(format, args...)). + Subjectf(format string, args ...any) Error } -func FromJSON(data []byte) (Error, bool) { - var j ErrorJSONMarshaller - if err := json.Unmarshal(data, &j); err != nil { - return nil, false - } - if j.Err == "" { - return nil, false - } - extras := make([]ErrorImpl, len(j.Extras)) - for i, e := range j.Extras { - extra, ok := fromJSONObject(e) - if !ok { - return nil, false - } - extras[i] = *extra - } - return &ErrorImpl{ - subject: j.Subject, - err: errors.New(j.Err), - extras: extras, - }, true -} - -func TryUnwrap(err error) error { - if err == nil { - return nil - } - if unwrapped := errors.Unwrap(err); unwrapped != nil { - return unwrapped - } - return err -} - -// Check is a helper function that -// convert (T, error) to (T, NestedError). -func Check[T any](obj T, err error) (T, Error) { - return obj, From(err) -} - -func Join(message string, err ...Error) Error { - extras := make([]ErrorImpl, len(err)) - nErr := 0 - for i, e := range err { - if e == nil { - continue - } - extras[i] = *e - nErr++ - } - if nErr == 0 { - return nil - } - return &ErrorImpl{ - err: errors.New(message), - extras: extras, - } -} - -func JoinE(message string, err ...error) Error { - b := NewBuilder("%s", message) - for _, e := range err { - b.AddE(e) - } - return b.Build() -} - -func IsNil(err error) bool { - return err == nil -} - -func IsNotNil(err error) bool { - return err != nil -} - -func (ne Error) String() string { - var buf strings.Builder - ne.writeToSB(&buf, 0, "") - return buf.String() -} - -func (ne Error) Is(err error) bool { - if ne == nil { - return err == nil - } - // return errors.Is(ne.err, err) - if errors.Is(ne.err, err) { - return true - } - for _, e := range ne.extras { - if e.Is(err) { - return true - } - } - return false -} - -func (ne Error) IsNot(err error) bool { - return !ne.Is(err) -} - -func (ne Error) Error() error { - if ne == nil { - return nil - } - return ne.buildError(0, "") -} - -func (ne Error) With(s any) Error { - if ne == nil { - return ne - } - var msg string - switch ss := s.(type) { - case nil: - return ne - case *ErrorImpl: - if len(ss.extras) == 1 { - ne.extras = append(ne.extras, ss.extras[0]) - return ne - } - return ne.withError(ss) - case error: - // unwrap only once - return ne.withError(From(TryUnwrap(ss))) - case string: - msg = ss - case fmt.Stringer: - return ne.appendMsg(ss.String()) - default: - return ne.appendMsg(fmt.Sprint(s)) - } - return ne.withError(From(errors.New(msg))) -} - -func (ne Error) Extraf(format string, args ...any) Error { - return ne.With(errorf(format, args...)) -} - -func (ne Error) Subject(s any, sep ...string) Error { - if ne == nil { - return ne - } - var subject string - switch ss := s.(type) { - case string: - subject = ss - case fmt.Stringer: - subject = ss.String() - default: - subject = fmt.Sprint(s) - } - switch { - case ne.subject == "": - ne.subject = subject - case len(sep) > 0: - ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject) - default: - ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject) - } - return ne -} - -func (ne Error) Subjectf(format string, args ...any) Error { - if ne == nil { - return ne - } - return ne.Subject(fmt.Sprintf(format, args...)) -} - -func (ne Error) JSONObject() ErrorJSONMarshaller { - extras := make([]ErrorJSONMarshaller, len(ne.extras)) - for i, e := range ne.extras { - extras[i] = e.JSONObject() - } - return ErrorJSONMarshaller{ - Subject: ne.subject, - Err: ne.err.Error(), - Extras: extras, - } -} - -func (ne Error) JSON() []byte { - b, err := json.MarshalIndent(ne.JSONObject(), "", " ") - if err != nil { - panic(err) - } - return b -} - -func (ne Error) NoError() bool { - return ne == nil -} - -func (ne Error) HasError() bool { - return ne != nil -} - -func errorf(format string, args ...any) Error { - for i, arg := range args { - if err, ok := arg.(error); ok { - if unwrapped := errors.Unwrap(err); unwrapped != nil { - args[i] = unwrapped - } - } - } - return From(fmt.Errorf(format, args...)) -} - -func fromJSONObject(obj ErrorJSONMarshaller) (Error, bool) { - data, err := json.Marshal(obj) - if err != nil { - return nil, false - } - return FromJSON(data) -} - -func (ne Error) withError(err Error) Error { - if ne != nil && err != nil { - ne.extras = append(ne.extras, *err) - } - return ne -} - -func (ne Error) appendMsg(msg string) Error { - if ne == nil { - return nil - } - ne.err = fmt.Errorf("%w %s", ne.err, msg) - return ne -} - -func (ne Error) writeToSB(sb *strings.Builder, level int, prefix string) { - for range level { - sb.WriteString(" ") - } - sb.WriteString(prefix) - - if ne.NoError() { - sb.WriteString("nil") - return - } - - if ne.subject != "" { - sb.WriteString(ne.subject) - sb.WriteRune(' ') - } - sb.WriteString(ne.err.Error()) - if len(ne.extras) > 0 { - sb.WriteRune(':') - for _, extra := range ne.extras { - sb.WriteRune('\n') - extra.writeToSB(sb, level+1, "- ") - } - } -} - -func (ne Error) buildError(level int, prefix string) error { - var res error - var sb strings.Builder - - for range level { - sb.WriteString(" ") - } - sb.WriteString(prefix) - - if ne.NoError() { - sb.WriteString("nil") - return errors.New(sb.String()) - } - - res = fmt.Errorf("%s%w", sb.String(), ne.err) - sb.Reset() - - if ne.subject != "" { - sb.WriteString(fmt.Sprintf(" for %q", ne.subject)) - } - if len(ne.extras) > 0 { - sb.WriteRune(':') - res = fmt.Errorf("%w%s", res, sb.String()) - for _, extra := range ne.extras { - res = errors.Join(res, extra.buildError(level+1, "- ")) - } - } else { - res = fmt.Errorf("%w%s", res, sb.String()) - } - return res +// this makes JSON marshalling work, +// as the builtin one doesn't. +type errStr string + +func (err errStr) Error() string { + return string(err) } diff --git a/internal/error/error_test.go b/internal/error/error_test.go index 5d0b9ea..0ac2902 100644 --- a/internal/error/error_test.go +++ b/internal/error/error_test.go @@ -1,107 +1,157 @@ -package error_test +package error import ( "errors" + "strings" "testing" - . "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) +func TestBaseString(t *testing.T) { + ExpectEqual(t, New("error").Error(), "error") +} + +func TestBaseWithSubject(t *testing.T) { + err := New("error") + withSubject := err.Subject("foo") + withSubjectf := err.Subjectf("%s %s", "foo", "bar") + + ExpectError(t, err, withSubject) + ExpectStrEqual(t, withSubject.Error(), "foo: error") + ExpectTrue(t, withSubject.Is(err)) + + ExpectError(t, err, withSubjectf) + ExpectStrEqual(t, withSubjectf.Error(), "foo bar: error") + ExpectTrue(t, withSubjectf.Is(err)) +} + +func TestBaseWithExtra(t *testing.T) { + err := New("error") + extra := New("bar").Subject("baz") + withExtra := err.With(extra) + + ExpectTrue(t, withExtra.Is(extra)) + ExpectTrue(t, withExtra.Is(err)) + + ExpectTrue(t, errors.Is(withExtra, extra)) + ExpectTrue(t, errors.Is(withExtra, err)) + + ExpectTrue(t, strings.Contains(withExtra.Error(), err.Error())) + ExpectTrue(t, strings.Contains(withExtra.Error(), extra.Error())) + ExpectTrue(t, strings.Contains(withExtra.Error(), "baz")) +} + +func TestBaseUnwrap(t *testing.T) { + err := errors.New("err") + wrapped := From(err) + + ExpectError(t, err, errors.Unwrap(wrapped)) +} + +func TestNestedUnwrap(t *testing.T) { + err := errors.New("err") + err2 := New("err2") + wrapped := From(err).Subject("foo").With(err2.Subject("bar")) + + unwrapper, ok := wrapped.(interface{ Unwrap() []error }) + ExpectTrue(t, ok) + + ExpectError(t, err, wrapped) + ExpectError(t, err2, wrapped) + ExpectEqual(t, len(unwrapper.Unwrap()), 2) +} + func TestErrorIs(t *testing.T) { - ExpectTrue(t, Failure("foo").Is(ErrFailure)) - ExpectTrue(t, Failure("foo").With("bar").Is(ErrFailure)) - ExpectFalse(t, Failure("foo").With("bar").Is(ErrInvalid)) - ExpectFalse(t, Failure("foo").With("bar").With("baz").Is(ErrInvalid)) + from := errors.New("error") + err := From(from) + ExpectError(t, from, err) - ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid)) - ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure)) + ExpectTrue(t, err.Is(from)) + ExpectFalse(t, err.Is(New("error"))) - ExpectFalse(t, Invalid("foo", "bar").Is(nil)) - - ExpectTrue(t, errors.Is(Failure("foo").Error(), ErrFailure)) - ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrInvalid)) - ExpectTrue(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrFailure)) - ExpectFalse(t, errors.Is(Failure("foo").With(Invalid("bar", "baz")).Error(), ErrNotExists)) + ExpectTrue(t, errors.Is(err.Subject("foo"), from)) + ExpectTrue(t, errors.Is(err.Withf("foo"), from)) + ExpectTrue(t, errors.Is(err.Subject("foo").Withf("bar"), from)) } -func TestErrorNestedIs(t *testing.T) { - var err Error - ExpectTrue(t, err.Is(nil)) +func TestErrorImmutability(t *testing.T) { + err := New("err") + err2 := New("err2") - err = Failure("some reason") - ExpectTrue(t, err.Is(ErrFailure)) - ExpectFalse(t, err.Is(ErrDuplicated)) + for range 3 { + // t.Logf("%d: %v %T %s", i, errors.Unwrap(err), err, err) + err.Subject("foo") + ExpectFalse(t, strings.Contains(err.Error(), "foo")) - err.With(Duplicated("something", "")) - ExpectTrue(t, err.Is(ErrFailure)) - ExpectTrue(t, err.Is(ErrDuplicated)) - ExpectFalse(t, err.Is(ErrInvalid)) -} + err.With(err2) + ExpectFalse(t, strings.Contains(err.Error(), "extra")) + ExpectFalse(t, err.Is(err2)) -func TestIsNil(t *testing.T) { - var err Error - ExpectTrue(t, err.Is(nil)) - ExpectTrue(t, err == nil) - ExpectTrue(t, err.NoError()) - - eb := NewBuilder("") - returnNil := func() error { - return eb.Build().Error() + err = err.Subject("bar").Withf("baz") + ExpectTrue(t, err != nil) } - ExpectTrue(t, IsNil(returnNil())) - ExpectTrue(t, returnNil() == nil) - - ExpectTrue(t, (err. - Subject("any"). - With("something"). - Extraf("foo %s", "bar")) == nil) -} - -func TestErrorSimple(t *testing.T) { - ne := Failure("foo bar") - ExpectEqual(t, ne.String(), "foo bar failed") - ne = ne.Subject("baz") - ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"") } func TestErrorWith(t *testing.T) { - ne := Failure("foo").With("bar").With("baz") - ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz") + err1 := New("err1") + err2 := New("err2") + + err3 := err1.With(err2) + + ExpectTrue(t, err3.Is(err1)) + ExpectTrue(t, err3.Is(err2)) + + err2.Subject("foo") + + ExpectTrue(t, err3.Is(err1)) + ExpectTrue(t, err3.Is(err2)) + + // check if err3 is affected by err2.Subject + ExpectFalse(t, strings.Contains(err3.Error(), "foo")) } -func TestErrorNested(t *testing.T) { - inner := Failure("inner"). - With("1"). - With("1") - inner2 := Failure("inner2"). +func TestErrorStringSimple(t *testing.T) { + errFailure := New("generic failure") + ne := errFailure.Subject("foo bar") + ExpectStrEqual(t, ne.Error(), "foo bar: generic failure") + ne = ne.Subject("baz") + ExpectStrEqual(t, ne.Error(), "baz > foo bar: generic failure") +} + +func TestErrorStringNested(t *testing.T) { + errFailure := New("generic failure") + inner := errFailure.Subject("inner"). + Withf("1"). + Withf("1") + inner2 := errFailure.Subject("inner2"). Subject("action 2"). - With("2"). - With("2") - inner3 := Failure("inner3"). + Withf("2"). + Withf("2") + inner3 := errFailure.Subject("inner3"). Subject("action 3"). - With("3"). - With("3") - ne := Failure("foo"). - With("bar"). - With("baz"). + Withf("3"). + Withf("3") + ne := errFailure. + Subject("foo"). + Withf("bar"). + Withf("baz"). With(inner). With(inner.With(inner2.With(inner3))) - want := `foo failed: - - bar - - baz - - inner failed: - - 1 - - 1 - - inner failed: - - 1 - - 1 - - inner2 failed for "action 2": - - 2 - - 2 - - inner3 failed for "action 3": - - 3 - - 3` - ExpectEqual(t, ne.String(), want) - ExpectEqual(t, ne.Error().Error(), want) + want := `foo: generic failure + • bar + • baz + • inner: generic failure + • 1 + • 1 + • inner: generic failure + • 1 + • 1 + • action 2 > inner2: generic failure + • 2 + • 2 + • action 3 > inner3: generic failure + • 3 + • 3` + ExpectStrEqual(t, ne.Error(), want) } diff --git a/internal/error/errors.go b/internal/error/errors.go deleted file mode 100644 index 4ae9214..0000000 --- a/internal/error/errors.go +++ /dev/null @@ -1,83 +0,0 @@ -package error - -import ( - stderrors "errors" - "fmt" - "reflect" -) - -var ( - ErrFailure = stderrors.New("failed") - ErrInvalid = stderrors.New("invalid") - ErrUnsupported = stderrors.New("unsupported") - ErrUnexpected = stderrors.New("unexpected") - ErrNotExists = stderrors.New("does not exist") - ErrMissing = stderrors.New("missing") - ErrDuplicated = stderrors.New("duplicated") - ErrOutOfRange = stderrors.New("out of range") - ErrTypeError = stderrors.New("type error") - ErrTypeMismatch = stderrors.New("type mismatch") - ErrPanicRecv = stderrors.New("panic recovered from") -) - -const fmtSubjectWhat = "%w %v: %q" - -func Failure(what string) Error { - return errorf("%s %w", what, ErrFailure) -} - -func FailedWhy(what string, why string) Error { - return Failure(what).With(why) -} - -func FailWith(what string, err any) Error { - return Failure(what).With(err) -} - -func Invalid(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrInvalid, subject, what) -} - -func Unsupported(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrUnsupported, subject, what) -} - -func Unexpected(subject, what any) Error { - return errorf(fmtSubjectWhat, ErrUnexpected, subject, what) -} - -func UnexpectedError(err error) Error { - return errorf("%w error: %w", ErrUnexpected, err) -} - -func NotExist(subject, what any) Error { - return errorf("%v %w: %v", subject, ErrNotExists, what) -} - -func Missing(subject any) Error { - return errorf("%w %v", ErrMissing, subject) -} - -func Duplicated(subject, what any) Error { - return errorf("%w %v: %v", ErrDuplicated, subject, what) -} - -func OutOfRange(subject any, value any) Error { - return errorf("%v %w: %v", subject, ErrOutOfRange, value) -} - -func TypeError(subject any, from, to reflect.Type) Error { - return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to) -} - -func TypeError2(subject any, from, to reflect.Value) Error { - return TypeError(subject, from.Type(), to.Type()) -} - -func TypeMismatch[Expect any](value any) Error { - return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) -} - -func PanicRecv(format string, args ...any) Error { - return errorf("%w %s", ErrPanicRecv, fmt.Sprintf(format, args...)) -} diff --git a/internal/error/log.go b/internal/error/log.go new file mode 100644 index 0000000..56990a8 --- /dev/null +++ b/internal/error/log.go @@ -0,0 +1,43 @@ +package error + +import ( + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/logging" +) + +func getLogger(logger ...*zerolog.Logger) *zerolog.Logger { + if len(logger) > 0 { + return logger[0] + } + return logging.GetLogger() +} + +//go:inline +func LogFatal(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Fatal().Msg(err.Error()) +} + +//go:inline +func LogError(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Error().Msg(err.Error()) +} + +//go:inline +func LogWarn(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Warn().Msg(err.Error()) +} + +//go:inline +func LogPanic(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Panic().Msg(err.Error()) +} + +//go:inline +func LogInfo(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Info().Msg(err.Error()) +} + +//go:inline +func LogDebug(msg string, err error, logger ...*zerolog.Logger) { + getLogger(logger...).Debug().Msg(err.Error()) +} diff --git a/internal/error/nested_error.go b/internal/error/nested_error.go new file mode 100644 index 0000000..a2fdc61 --- /dev/null +++ b/internal/error/nested_error.go @@ -0,0 +1,120 @@ +package error + +import ( + "errors" + "fmt" + "strings" +) + +type nestedError struct { + Err error `json:"err"` + Extras []error `json:"extras"` +} + +func (err nestedError) Subject(subject string) Error { + if err.Err == nil { + err.Err = newError(subject) + } else { + err.Err = PrependSubject(subject, err.Err) + } + return &err +} + +func (err *nestedError) Subjectf(format string, args ...any) Error { + if len(args) > 0 { + return err.Subject(fmt.Sprintf(format, args...)) + } + return err.Subject(format) +} + +func (err nestedError) With(extra error) Error { + if extra != nil { + err.Extras = append(err.Extras, extra) + } + return &err +} + +func (err nestedError) Withf(format string, args ...any) Error { + if len(args) > 0 { + err.Extras = append(err.Extras, fmt.Errorf(format, args...)) + } else { + err.Extras = append(err.Extras, newError(format)) + } + return &err +} + +func (err *nestedError) Unwrap() []error { + if err.Err == nil { + if len(err.Extras) == 0 { + return nil + } + return err.Extras + } + return append([]error{err.Err}, err.Extras...) +} + +func (err *nestedError) Is(other error) bool { + if errors.Is(err.Err, other) { + return true + } + for _, e := range err.Extras { + if errors.Is(e, other) { + return true + } + } + return false +} + +func (err *nestedError) Error() string { + return buildError(err, 0) +} + +//go:inline +func makeLine(err string, level int) string { + const bulletPrefix = "• " + const spaces = " " + + if level == 0 { + return err + } + return spaces[:2*level] + bulletPrefix + err +} + +func makeLines(errs []error, level int) []string { + if len(errs) == 0 { + return nil + } + lines := make([]string, 0, len(errs)) + for _, err := range errs { + switch err := err.(type) { + case *nestedError: + if err.Err != nil { + lines = append(lines, makeLine(err.Err.Error(), level)) + } + if extras := makeLines(err.Extras, level+1); len(extras) > 0 { + lines = append(lines, extras...) + } + default: + lines = append(lines, makeLine(err.Error(), level)) + } + } + return lines +} + +func buildError(err error, level int) string { + switch err := err.(type) { + case nil: + return makeLine("", level) + case *nestedError: + lines := make([]string, 0, 1+len(err.Extras)) + if err.Err != nil { + lines = append(lines, makeLine(err.Err.Error(), level)) + } + if extras := makeLines(err.Extras, level+1); len(extras) > 0 { + lines = append(lines, extras...) + } + return strings.Join(lines, "\n") + default: + return makeLine(err.Error(), level) + } +} diff --git a/internal/error/subject.go b/internal/error/subject.go new file mode 100644 index 0000000..c78ef2d --- /dev/null +++ b/internal/error/subject.go @@ -0,0 +1,50 @@ +package error + +import ( + "strings" + + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" +) + +type withSubject struct { + Subject string `json:"subject"` + Err error `json:"err"` +} + +const subjectSep = " > " + +func highlight(subject string) string { + return ansi.HighlightRed + subject + ansi.Reset +} + +func PrependSubject(subject string, err error) *withSubject { + switch err := err.(type) { + case *withSubject: + return err.Prepend(subject) + case *baseError: + return PrependSubject(subject, err.Err) + default: + return &withSubject{subject, err} + } +} + +func (err withSubject) Prepend(subject string) *withSubject { + if subject != "" { + err.Subject = subject + subjectSep + err.Subject + } + return &err +} + +func (err *withSubject) Is(other error) bool { + return err.Err == other +} + +func (err *withSubject) Unwrap() error { + return err.Err +} + +func (err *withSubject) Error() string { + subjects := strings.Split(err.Subject, subjectSep) + subjects[len(subjects)-1] = highlight(subjects[len(subjects)-1]) + return strings.Join(subjects, subjectSep) + ": " + err.Err.Error() +} diff --git a/internal/error/utils.go b/internal/error/utils.go new file mode 100644 index 0000000..6bd6c73 --- /dev/null +++ b/internal/error/utils.go @@ -0,0 +1,71 @@ +package error + +import ( + "errors" + "fmt" +) + +var ErrInvalidErrorJson = errors.New("invalid error json") + +func newError(message string) error { + return errStr(message) +} + +func New(message string) Error { + if message == "" { + return nil + } + return &baseError{newError(message)} +} + +func Errorf(format string, args ...any) Error { + return &baseError{fmt.Errorf(format, args...)} +} + +func From(err error) Error { + if err == nil { + return nil + } + if err, ok := err.(Error); ok { + return err + } + return &baseError{err} +} + +func Must[T any](v T, err error) T { + if err != nil { + LogPanic("must failed", err) + } + return v +} + +func Join(errors ...error) Error { + n := 0 + for _, err := range errors { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + errs := make([]error, 0, n) + for _, err := range errors { + if err != nil { + errs = append(errs, err) + } + } + return &nestedError{Extras: errs} +} + +func Collect[T any, Err error, Arg any, Func func(Arg) (T, Err)](eb *Builder, fn Func, arg Arg) T { + result, err := fn(arg) + eb.Add(err) + return result +} + +func Collect2[T any, Err error, Arg1 any, Arg2 any, Func func(Arg1, Arg2) (T, Err)](eb *Builder, fn Func, arg1 Arg1, arg2 Arg2) T { + result, err := fn(arg1, arg2) + eb.Add(err) + return result +} diff --git a/internal/list-icons.go b/internal/list-icons.go index 1764901..8aff4bc 100644 --- a/internal/list-icons.go +++ b/internal/list-icons.go @@ -53,7 +53,7 @@ func ListAvailableIcons() ([]string, error) { icons = append(icons, content.Path) } } - err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error() + err = utils.SaveJSON(iconsCachePath, &icons, 0o644) if err != nil { log.Print("error saving cache", err) } diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..664ea4a --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,69 @@ +package logging + +import ( + "os" + "strings" + + "github.com/rs/zerolog" + "github.com/yusing/go-proxy/internal/common" +) + +var logger zerolog.Logger + +func init() { + var timeFmt string + var level zerolog.Level + var exclude []string + + if common.IsTrace { + timeFmt = "04:05" + level = zerolog.TraceLevel + } else if common.IsDebug { + timeFmt = "01-02 15:04" + level = zerolog.DebugLevel + } else { + timeFmt = "01-02 15:04" + level = zerolog.InfoLevel + exclude = []string{"module"} + } + + prefixLength := len(timeFmt) + 5 // level takes 3 + 2 spaces + prefix := strings.Repeat(" ", prefixLength) + + logger = zerolog.New( + zerolog.ConsoleWriter{ + Out: os.Stderr, + TimeFormat: timeFmt, + FieldsExclude: exclude, + FormatMessage: func(msgI interface{}) string { // pad spaces for each line + msg := msgI.(string) + lines := strings.Split(msg, "\n") + if len(lines) == 1 { + return msg + } + for i := 1; i < len(lines); i++ { + lines[i] = prefix + lines[i] + } + return strings.Join(lines, "\n") + }, + }, + ).Level(level).With().Timestamp().Logger() +} + +func DiscardLogger() { logger = zerolog.Nop() } + +func AddHook(h zerolog.Hook) { logger = logger.Hook(h) } + +func GetLogger() *zerolog.Logger { return &logger } +func With() zerolog.Context { return logger.With() } + +func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) } + +func Info() *zerolog.Event { return logger.Info() } +func Warn() *zerolog.Event { return logger.Warn() } +func Error() *zerolog.Event { return logger.Error() } +func Err(err error) *zerolog.Event { return logger.Err(err) } +func Debug() *zerolog.Event { return logger.Debug() } +func Fatal() *zerolog.Event { return logger.Fatal() } +func Panic() *zerolog.Event { return logger.Panic() } +func Trace() *zerolog.Event { return logger.Trace() } diff --git a/internal/net/http/dummy_response_writer.go b/internal/net/http/dummy_response_writer.go new file mode 100644 index 0000000..0e5a1a9 --- /dev/null +++ b/internal/net/http/dummy_response_writer.go @@ -0,0 +1,15 @@ +package http + +import "net/http" + +type DummyResponseWriter struct{} + +func (w DummyResponseWriter) Header() http.Header { + return make(http.Header) +} + +func (w DummyResponseWriter) Write([]byte) (_ int, _ error) { + return +} + +func (w DummyResponseWriter) WriteHeader(int) {} diff --git a/internal/net/http/loadbalancer/dummy_response_writer.go b/internal/net/http/loadbalancer/dummy_response_writer.go deleted file mode 100644 index d6ea9f0..0000000 --- a/internal/net/http/loadbalancer/dummy_response_writer.go +++ /dev/null @@ -1,15 +0,0 @@ -package loadbalancer - -import "net/http" - -type DummyResponseWriter struct{} - -func (w *DummyResponseWriter) Header() (_ http.Header) { - return -} - -func (w *DummyResponseWriter) Write([]byte) (_ int, _ error) { - return -} - -func (w *DummyResponseWriter) WriteHeader(int) {} diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index 62e1a51..f952932 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -11,20 +11,22 @@ import ( ) type ipHash struct { + *LoadBalancer + realIP *middleware.Middleware pool servers mu sync.Mutex } func (lb *LoadBalancer) newIPHash() impl { - impl := new(ipHash) + impl := &ipHash{LoadBalancer: lb} if len(lb.Options) == 0 { return impl } var err E.Error impl.realIP, err = middleware.NewRealIP(lb.Options) if err != nil { - logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) + E.LogError("invalid real_ip options, ignoring", err, &impl.Logger) } return impl } @@ -70,7 +72,7 @@ func (impl *ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { http.Error(rw, "Internal error", http.StatusInternalServerError) - logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err) + impl.Err(err).Msg("invalid remote address " + r.RemoteAddr) return } idx := hashIP(ip) % uint32(len(impl.pool)) diff --git a/internal/net/http/loadbalancer/least_conn.go b/internal/net/http/loadbalancer/least_conn.go index 2ca9794..8fe8894 100644 --- a/internal/net/http/loadbalancer/least_conn.go +++ b/internal/net/http/loadbalancer/least_conn.go @@ -31,14 +31,14 @@ func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.R srv := srvs[0] minConn, ok := impl.nConn.Load(srv) if !ok { - logger.Errorf("[BUG] server %s not found", srv.Name) + impl.Error().Msgf("[BUG] server %s not found", srv.Name) http.Error(rw, "Internal error", http.StatusInternalServerError) } for i := 1; i < len(srvs); i++ { nConn, ok := impl.nConn.Load(srvs[i]) if !ok { - logger.Errorf("[BUG] server %s not found", srv.Name) + impl.Error().Msgf("[BUG] server %s not found", srv.Name) http.Error(rw, "Internal error", http.StatusInternalServerError) } if nConn.Load() < minConn.Load() { diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index ed860e6..4812116 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -6,9 +6,11 @@ import ( "sync" "time" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" idlewatcher "github.com/yusing/go-proxy/internal/docker/idlewatcher/types" E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/middleware" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health" @@ -29,6 +31,8 @@ type ( Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"` } LoadBalancer struct { + zerolog.Logger + impl *Config @@ -48,6 +52,7 @@ const maxWeight weightType = 100 func New(cfg *Config) *LoadBalancer { lb := &LoadBalancer{ + Logger: logger.With().Str("name", cfg.Link).Logger(), Config: new(Config), pool: newPool(), task: task.DummyTask(), @@ -102,7 +107,7 @@ func (lb *LoadBalancer) UpdateConfigIfNeeded(cfg *Config) { if lb.Mode == Unset && cfg.Mode != Unset { lb.Mode = cfg.Mode if !lb.Mode.ValidateUpdate() { - logger.Warnf("loadbalancer %s: invalid mode %q, fallback to %q", cfg.Link, cfg.Mode, lb.Mode) + lb.Error().Msgf("invalid mode %q, fallback to %q", cfg.Mode, lb.Mode) } lb.updateImpl() } @@ -131,7 +136,11 @@ func (lb *LoadBalancer) AddServer(srv *Server) { lb.rebalance() lb.impl.OnAddServer(srv) - logger.Debugf("[add] %s to loadbalancer %s: %d servers available", srv.Name, lb.Link, lb.pool.Size()) + + lb.Debug(). + Str("action", "add"). + Str("server", srv.Name). + Msgf("%d servers available", lb.pool.Size()) } func (lb *LoadBalancer) RemoveServer(srv *Server) { @@ -148,13 +157,15 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { lb.rebalance() lb.impl.OnRemoveServer(srv) + lb.Debug(). + Str("action", "remove"). + Str("server", srv.Name). + Msgf("%d servers left", lb.pool.Size()) + if lb.pool.Size() == 0 { lb.task.Finish("no server left") - logger.Infof("loadbalancer %s stopped", lb.Link) return } - - logger.Debugf("[remove] %s from loadbalancer %s: %d servers left", srv.Name, lb.Link, lb.pool.Size()) } func (lb *LoadBalancer) rebalance() { @@ -218,7 +229,7 @@ func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) defer cancel() // send dummy request to wake all servers - var dummyRW *DummyResponseWriter + var dummyRW gphttp.DummyResponseWriter for _, srv := range srvs { // wake only if server implements Waker _, ok := srv.handler.(idlewatcher.Waker) diff --git a/internal/net/http/loadbalancer/logger.go b/internal/net/http/loadbalancer/logger.go index 7b9b51d..30fac46 100644 --- a/internal/net/http/loadbalancer/logger.go +++ b/internal/net/http/loadbalancer/logger.go @@ -1,5 +1,5 @@ package loadbalancer -import "github.com/sirupsen/logrus" +import "github.com/yusing/go-proxy/internal/logging" -var logger = logrus.WithField("module", "load_balancer") +var logger = logging.With().Str("module", "load_balancer").Logger() diff --git a/internal/net/http/logger.go b/internal/net/http/logger.go new file mode 100644 index 0000000..3e52c93 --- /dev/null +++ b/internal/net/http/logger.go @@ -0,0 +1,5 @@ +package http + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "http").Logger() diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index 5702851..ec19869 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -15,9 +15,9 @@ type cidrWhitelist struct { } type cidrWhitelistOpts struct { - Allow []*types.CIDR - StatusCode int - Message string + Allow []*types.CIDR `json:"allow"` + StatusCode int `json:"statusCode"` + Message string `json:"message"` cachedAddr F.Map[string, bool] // cache for trusted IPs } @@ -47,7 +47,7 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.Error) { return nil, err } if len(wl.cidrWhitelistOpts.Allow) == 0 { - return nil, E.Missing("allow range") + return nil, E.New("no allowed CIDRs") } return wl.m, nil } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go index 0daeb9d..dd5fc69 100644 --- a/internal/net/http/middleware/cidr_whitelist_test.go +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -13,10 +14,9 @@ var testCIDRWhitelistCompose []byte var deny, accept *Middleware func TestCIDRWhitelist(t *testing.T) { - mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose) - if err != nil { - panic(err) - } + errs := E.NewBuilder("") + mids := BuildMiddlewaresFromYAML("", testCIDRWhitelistCompose, errs) + ExpectNoError(t, errs.Error()) deny = mids["deny@file"] accept = mids["accept@file"] if deny == nil || accept == nil { @@ -26,7 +26,7 @@ func TestCIDRWhitelist(t *testing.T) { t.Run("deny", func(t *testing.T) { for range 10 { result, err := newMiddlewareTest(deny, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode) ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message) } @@ -35,7 +35,7 @@ func TestCIDRWhitelist(t *testing.T) { t.Run("accept", func(t *testing.T) { for range 10 { result, err := newMiddlewareTest(accept, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusOK) } }) diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 20e7ff3..f2c5b92 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -10,10 +10,10 @@ import ( "sync" "time" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/strutils" ) const ( @@ -26,7 +26,7 @@ const ( var ( cfCIDRsLastUpdate time.Time cfCIDRsMu sync.Mutex - cfCIDRsLogger = logrus.WithField("middleware", "CloudflareRealIP") + cfCIDRsLogger = logger.With().Str("name", "CloudflareRealIP").Logger() ) var CloudflareRealIP = &realIP{ @@ -80,13 +80,13 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { ) if err != nil { cfCIDRsLastUpdate = time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval) - cfCIDRsLogger.Errorf("failed to update cloudflare range: %s, retry in %s", err, cfCIDRsUpdateRetryInterval) + cfCIDRsLogger.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) return nil } } cfCIDRsLastUpdate = time.Now() - cfCIDRsLogger.Debugf("cloudflare CIDR range updated") + cfCIDRsLogger.Info().Msg("cloudflare CIDR range updated") return } diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index f06b686..6b764e2 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -8,24 +8,31 @@ import ( "strconv" "strings" - "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" ) -var CustomErrorPage = &Middleware{ - before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { - if !ServeStaticErrorPageFile(w, r) { - next(w, r) - } - }, - modifyResponse: func(resp *Response) error { +var CustomErrorPage *Middleware + +func init() { + CustomErrorPage = customErrorPage() +} + +func customErrorPage() *Middleware { + m := &Middleware{ + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { + if !ServeStaticErrorPageFile(w, r) { + next(w, r) + } + }, + } + m.modifyResponse = func(resp *Response) error { // only handles non-success status code and html/plain content type contentType := gphttp.GetContentType(resp.Header) if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) if ok { - errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode) + CustomErrorPage.Debug().Msgf("error page for status %d loaded", resp.StatusCode) /* trunk-ignore(golangci-lint/errcheck) */ io.Copy(io.Discard, resp.Body) // drain the original body resp.Body.Close() @@ -34,12 +41,13 @@ var CustomErrorPage = &Middleware{ resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) resp.Header.Set("Content-Type", "text/html; charset=utf-8") } else { - errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode) + CustomErrorPage.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } return nil } return nil - }, + } + return m } func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { @@ -51,7 +59,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { filename := path[len(gphttp.StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { - errPageLogger.Errorf("unable to load resource %s", filename) + logger.Error().Msg("unable to load resource " + filename) return false } ext := filepath.Ext(filename) @@ -63,15 +71,13 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { case ".css": w.Header().Set("Content-Type", "text/css; charset=utf-8") default: - errPageLogger.Errorf("unexpected file type %q for %s", ext, filename) + logger.Error().Msgf("unexpected file type %q for %s", ext, filename) } if _, err := w.Write(file); err != nil { - errPageLogger.WithError(err).Errorf("unable to write resource %s", filename) + logger.Err(err).Msg("unable to write resource " + filename) http.Error(w, "Error page failure", http.StatusInternalServerError) } return true } return false } - -var errPageLogger = logrus.WithField("middleware", "error_page") diff --git a/internal/api/v1/errorpage/error_page.go b/internal/net/http/middleware/errorpage/error_page.go similarity index 66% rename from internal/api/v1/errorpage/error_page.go rename to internal/net/http/middleware/errorpage/error_page.go index cb796bf..7f63d5d 100644 --- a/internal/api/v1/errorpage/error_page.go +++ b/internal/net/http/middleware/errorpage/error_page.go @@ -1,14 +1,15 @@ package errorpage import ( - "context" "fmt" "os" "path" "sync" - . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" W "github.com/yusing/go-proxy/internal/watcher" @@ -23,9 +24,10 @@ var ( ) var setup = sync.OnceFunc(func() { - dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath) + task := task.GlobalTask("error page") + dirWatcher = W.NewDirectoryWatcher(task.Subtask("dir watcher"), errPagesBasePath) loadContent() - go watchDir() + go watchDir(task) }) func GetStaticFile(filename string) ([]byte, bool) { @@ -44,7 +46,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { func loadContent() { files, err := U.ListFiles(errPagesBasePath, 0) if err != nil { - Logger.Error(err) + logger.Err(err).Msg("failed to list error page resources") return } for _, file := range files { @@ -53,19 +55,21 @@ func loadContent() { } content, err := os.ReadFile(file) if err != nil { - Logger.Errorf("failed to read error page resource %s: %s", file, err) + logger.Warn().Err(err).Msgf("failed to read error page resource %s", file) continue } file = path.Base(file) - Logger.Infof("error page resource %s loaded", file) + logging.Info().Msgf("error page resource %s loaded", file) fileContentMap.Store(file, content) } } -func watchDir() { - eventCh, errCh := dirWatcher.Events(context.Background()) +func watchDir(task task.Task) { + eventCh, errCh := dirWatcher.Events(task.Context()) for { select { + case <-task.Context().Done(): + return case event, ok := <-eventCh: if !ok { return @@ -77,14 +81,14 @@ func watchDir() { loadContent() case events.ActionFileDeleted: fileContentMap.Delete(filename) - Logger.Infof("error page resource %s deleted", filename) + logger.Warn().Msgf("error page resource %s deleted", filename) case events.ActionFileRenamed: - Logger.Infof("error page resource %s deleted", filename) + logger.Warn().Msgf("error page resource %s deleted", filename) fileContentMap.Delete(filename) loadContent() } case err := <-errCh: - Logger.Errorf("error watching error page directory: %s", err) + E.LogError("error watching error page directory", err, &logger) } } } diff --git a/internal/net/http/middleware/errorpage/logger.go b/internal/net/http/middleware/errorpage/logger.go new file mode 100644 index 0000000..bc0fc30 --- /dev/null +++ b/internal/net/http/middleware/errorpage/logger.go @@ -0,0 +1,5 @@ +package errorpage + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "errorpage").Logger() diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 93e51d5..29ce004 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -24,11 +24,12 @@ type ( client http.Client } forwardAuthOpts struct { - Address string - TrustForwardHeader bool - AuthResponseHeaders []string - AddAuthCookiesToResponse []string - transport http.RoundTripper + Address string `json:"address"` + TrustForwardHeader bool `json:"trustForwardHeader"` + AuthResponseHeaders []string `json:"authResponseHeaders"` + AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"` + + transport http.RoundTripper } ) @@ -39,13 +40,11 @@ var ForwardAuth = &forwardAuth{ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) { fa := new(forwardAuth) fa.forwardAuthOpts = new(forwardAuthOpts) - err := Deserialize(optsRaw, fa.forwardAuthOpts) - if err != nil { + if err := Deserialize(optsRaw, fa.forwardAuthOpts); err != nil { return nil, err } - _, err = E.Check(url.Parse(fa.Address)) - if err != nil { - return nil, E.Invalid("address", fa.Address) + if _, err := url.Parse(fa.Address); err != nil { + return nil, E.From(err) } fa.m = &Middleware{ diff --git a/internal/net/http/middleware/logger.go b/internal/net/http/middleware/logger.go new file mode 100644 index 0000000..643f9b6 --- /dev/null +++ b/internal/net/http/middleware/logger.go @@ -0,0 +1,5 @@ +package middleware + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "middleware").Logger() diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 804450a..0b2b80e 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" U "github.com/yusing/go-proxy/internal/utils" @@ -32,6 +33,8 @@ type ( Middleware struct { _ U.NoCopy + zerolog.Logger + name string before BeforeFunc // runs before ReverseProxy.ServeHTTP @@ -78,13 +81,19 @@ func (m *Middleware) MarshalJSON() ([]byte, error) { } func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) { - if len(optsRaw) != 0 && m.withOptions != nil { - return m.withOptions(optsRaw) + if m.withOptions != nil { + m, err := m.withOptions(optsRaw) + if err != nil { + return nil, err + } + m.Logger = logger.With().Str("name", m.name).Logger() + return m, nil } // WithOptionsClone is called only once // set withOptions and labelParser will not be used after that return &Middleware{ + Logger: logger.With().Str("name", m.name).Logger(), name: m.name, before: m.before, modifyResponse: m.modifyResponse, @@ -108,24 +117,20 @@ func (m *Middleware) ModifyResponse(resp *Response) error { } // TODO: check conflict or duplicates. -func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.Error) { - middlewares = make([]*Middleware, 0, len(middlewaresMap)) +func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) { + middlewares := make([]*Middleware, 0, len(middlewaresMap)) - invalidM := E.NewBuilder("invalid middlewares") - invalidOpts := E.NewBuilder("invalid options") - defer func() { - invalidM.Add(invalidOpts.Build()) - invalidM.To(&res) - }() + errs := E.NewBuilder("middlewares compile error") + invalidOpts := E.NewBuilder("options compile error") for name, opts := range middlewaresMap { - m, ok := Get(name) - if !ok { - invalidM.Add(E.NotExist("middleware", name)) + m, err := Get(name) + if err != nil { + errs.Add(err) continue } - m, err := m.WithOptionsClone(opts) + m, err = m.WithOptionsClone(opts) if err != nil { invalidOpts.Add(err.Subject(name)) continue @@ -133,7 +138,10 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Mid middlewares = append(middlewares, m) } - return + if invalidOpts.HasError() { + errs.Add(invalidOpts.Error()) + } + return middlewares, errs.Error() } func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) { diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 19bd92f..328cc64 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -4,64 +4,60 @@ import ( "fmt" "net/http" "os" + "path" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) -func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.Error) { +func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]*Middleware { fileContent, err := os.ReadFile(filePath) if err != nil { - return nil, E.FailWith("read middleware compose file", err) + eb.Add(err) + return nil } - return BuildMiddlewaresFromYAML(fileContent) + return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) } -func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.Error) { - b := E.NewBuilder("middlewares compile errors") - defer b.To(&outErr) - +func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { var rawMap map[string][]map[string]any err := yaml.Unmarshal(data, &rawMap) if err != nil { - b.Add(E.FailWith("yaml unmarshal", err)) - return + eb.Add(err) + return nil } - middlewares = make(map[string]*Middleware) + middlewares := make(map[string]*Middleware) for name, defs := range rawMap { - chainErr := E.NewBuilder("%s", name) + chainErr := E.NewBuilder("") chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"] == "" { - chainErr.Add(E.Missing("use").Subjectf(".%d", i)) + chainErr.Addf("item %d: missing field 'use'", i) continue } baseName := def["use"].(string) - base, ok := Get(baseName) - if !ok { - base, ok = middlewares[baseName] - if !ok { - chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i)) - continue - } + base, err := Get(baseName) + if err != nil { + chainErr.Add(err.Subjectf("%s[%d]", name, i)) + continue } delete(def, "use") m, err := base.WithOptionsClone(def) if err != nil { - chainErr.Add(err.Subjectf("item%d", i)) + chainErr.Add(err.Subjectf("%s[%d]", name, i)) continue } m.name = fmt.Sprintf("%s[%d]", name, i) chain = append(chain, m) } if chainErr.HasError() { - b.Add(chainErr.Build()) + eb.Add(chainErr.Error().Subject(source)) } else { middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain) } } - return + return middlewares } // TODO: check conflict or duplicates. @@ -86,11 +82,13 @@ func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { } if len(modResps) > 0 { m.modifyResponse = func(res *Response) error { - b := E.NewBuilder("errors in middleware") + errs := E.NewBuilder("modify response errors") for _, mr := range modResps { - b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name)) + if err := mr.modifyResponse(res); err != nil { + errs.Add(E.From(err).Subject(mr.name)) + } } - return b.Build().Error() + return errs.Error() } } diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/http/middleware/middleware_builder_test.go index d7fca0c..914655a 100644 --- a/internal/net/http/middleware/middleware_builder_test.go +++ b/internal/net/http/middleware/middleware_builder_test.go @@ -13,10 +13,10 @@ import ( var testMiddlewareCompose []byte func TestBuild(t *testing.T) { - middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) - ExpectNoError(t, err.Error()) - _, err = E.Check(json.MarshalIndent(middlewares, "", " ")) - ExpectNoError(t, err.Error()) + errs := E.NewBuilder("") + middlewares := BuildMiddlewaresFromYAML("", testMiddlewareCompose, errs) + ExpectNoError(t, errs.Error()) + E.Must(json.MarshalIndent(middlewares, "", " ")) // t.Log(string(data)) // TODO: test } diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 26e67a7..4ff3699 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -6,26 +6,37 @@ import ( "path" "strings" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) -var middlewares map[string]*Middleware +var allMiddlewares map[string]*Middleware -func Get(name string) (middleware *Middleware, ok bool) { - middleware, ok = middlewares[U.ToLowerNoSnake(name)] - return +var ( + ErrUnknownMiddleware = E.New("unknown middleware") + ErrDuplicatedMiddleware = E.New("duplicated middleware") +) + +func Get(name string) (*Middleware, Error) { + middleware, ok := allMiddlewares[U.ToLowerNoSnake(name)] + if !ok { + return nil, ErrUnknownMiddleware. + Subject(name). + Withf(strutils.DoYouMean(utils.NearestField(name, allMiddlewares))) + } + return middleware, nil } func All() map[string]*Middleware { - return middlewares + return allMiddlewares } // initialize middleware names and label parsers func init() { - middlewares = map[string]*Middleware{ + allMiddlewares = map[string]*Middleware{ "setxforwarded": SetXForwarded, "hidexforwarded": HideXForwarded, "redirecthttp": RedirectHTTP, @@ -39,10 +50,10 @@ func init() { // !experimental "forwardauth": ForwardAuth.m, - "oauth2": OAuth2.m, + // "oauth2": OAuth2.m, } names := make(map[*Middleware][]string) - for name, m := range middlewares { + for name, m := range allMiddlewares { names[m] = append(names[m], http.CanonicalHeaderKey(name)) } for m, names := range names { @@ -55,27 +66,30 @@ func init() { } func LoadComposeFiles() { - b := E.NewBuilder("failed to load middlewares") + errs := E.NewBuilder("middleware compile errors") middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0) if err != nil { - logrus.Errorf("failed to list middleware definitions: %s", err) + logger.Err(err).Msg("failed to list middleware definitions") return } for _, defFile := range middlewareDefs { - mws, err := BuildMiddlewaresFromComposeFile(defFile) + mws := BuildMiddlewaresFromComposeFile(defFile, errs) + if len(mws) == 0 { + continue + } for name, m := range mws { - if _, ok := middlewares[name]; ok { - b.Add(E.Duplicated("middleware", name)) + if _, ok := allMiddlewares[name]; ok { + errs.Add(ErrDuplicatedMiddleware.Subject(name)) continue } - middlewares[U.ToLowerNoSnake(name)] = m - logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) + allMiddlewares[U.ToLowerNoSnake(name)] = m + logger.Info(). + Str("name", name). + Str("src", path.Base(defFile)). + Msg("middleware loaded") } - b.Add(err.Subject(path.Base(defFile))) } - if b.HasError() { - logger.Error(b.Build()) + if errs.HasError() { + E.LogError(errs.About(), errs.Error(), &logger) } } - -var logger = logrus.WithField("module", "middlewares") diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 2b9ca2e..0b0ce60 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -12,9 +12,9 @@ type ( } // order: set_headers -> add_headers -> hide_headers modifyRequestOpts struct { - SetHeaders map[string]string - AddHeaders map[string]string - HideHeaders []string + SetHeaders map[string]string `json:"setHeaders"` + AddHeaders map[string]string `json:"addHeaders"` + HideHeaders []string `json:"hideHeaders"` } ) diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 1590e11..6d9d9ec 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -16,7 +16,7 @@ func TestSetModifyRequest(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyRequest.m.WithOptionsClone(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).HideHeaders, opts["hide_headers"].([]string)) @@ -26,7 +26,7 @@ func TestSetModifyRequest(t *testing.T) { result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ middlewareOpt: opts, }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index 62011d8..4edd710 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -13,11 +13,7 @@ type ( m *Middleware } // order: set_headers -> add_headers -> hide_headers - modifyResponseOpts struct { - SetHeaders map[string]string - AddHeaders map[string]string - HideHeaders []string - } + modifyResponseOpts = modifyRequestOpts ) var ModifyResponse = &modifyResponse{ diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 65e98d6..370e590 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -16,7 +16,7 @@ func TestSetModifyResponse(t *testing.T) { t.Run("set_options", func(t *testing.T) { mr, err := ModifyResponse.m.WithOptionsClone(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) @@ -26,7 +26,7 @@ func TestSetModifyResponse(t *testing.T) { result, err := newMiddlewareTest(ModifyResponse.m, &testArgs{ middlewareOpt: opts, }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0") t.Log(result.ResponseHeaders.Get("Accept-Encoding")) ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go index b500bee..10ea6d6 100644 --- a/internal/net/http/middleware/oauth2.go +++ b/internal/net/http/middleware/oauth2.go @@ -1,129 +1,117 @@ package middleware -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "reflect" +// import ( +// "encoding/json" +// "fmt" +// "net/http" +// "net/url" - E "github.com/yusing/go-proxy/internal/error" -) +// E "github.com/yusing/go-proxy/internal/error" +// ) -type oAuth2 struct { - *oAuth2Opts - m *Middleware -} +// type oAuth2 struct { +// oAuth2Opts +// m *Middleware +// } -type oAuth2Opts struct { - ClientID string - ClientSecret string - AuthURL string // Authorization Endpoint - TokenURL string // Token Endpoint -} +// type oAuth2Opts struct { +// ClientID string `validate:"required"` +// ClientSecret string `validate:"required"` +// AuthURL string `validate:"required"` // Authorization Endpoint +// TokenURL string `validate:"required"` // Token Endpoint +// } -var OAuth2 = &oAuth2{ - m: &Middleware{withOptions: NewAuthentikOAuth2}, -} +// var OAuth2 = &oAuth2{ +// m: &Middleware{withOptions: NewAuthentikOAuth2}, +// } -func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { - oauth := new(oAuth2) - oauth.m = &Middleware{ - impl: oauth, - before: oauth.handleOAuth2, - } - oauth.oAuth2Opts = &oAuth2Opts{} - err := Deserialize(opts, oauth.oAuth2Opts) - if err != nil { - return nil, err - } - b := E.NewBuilder("missing required fields") - optV := reflect.ValueOf(oauth.oAuth2Opts) - for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) { - if optV.FieldByName(field.Name).Len() == 0 { - b.Add(E.Missing(field.Name)) - } - } - if b.HasError() { - return nil, b.Build().Subject("oAuth2") - } - return oauth.m, nil -} +// func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.Error) { +// oauth := new(oAuth2) +// oauth.m = &Middleware{ +// impl: oauth, +// before: oauth.handleOAuth2, +// } +// err := Deserialize(opts, &oauth.oAuth2Opts) +// if err != nil { +// return nil, err +// } +// return oauth.m, nil +// } -func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { - // Check if the user is authenticated (you may use session, cookie, etc.) - if !userIsAuthenticated(r) { - // TODO: Redirect to OAuth2 auth URL - http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", - oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) - return - } +// func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) { +// // Check if the user is authenticated (you may use session, cookie, etc.) +// if !userIsAuthenticated(r) { +// // TODO: Redirect to OAuth2 auth URL +// http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code", +// oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound) +// return +// } - // If you have a token in the query string, process it - if code := r.URL.Query().Get("code"); code != "" { - // Exchange the authorization code for a token here - // Use the TokenURL and authenticate the user - token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI) - if err != nil { - // handle error - http.Error(rw, "failed to get token", http.StatusUnauthorized) - return - } +// // If you have a token in the query string, process it +// if code := r.URL.Query().Get("code"); code != "" { +// // Exchange the authorization code for a token here +// // Use the TokenURL and authenticate the user +// token, err := exchangeCodeForToken(code, &oauth.oAuth2Opts, r.RequestURI) +// if err != nil { +// // handle error +// http.Error(rw, "failed to get token", http.StatusUnauthorized) +// return +// } - // Save token and user info based on your requirements - saveToken(rw, token) +// // Save token and user info based on your requirements +// saveToken(rw, token) - // Redirect to the originally requested URL - http.Redirect(rw, r, "/", http.StatusFound) - return - } +// // Redirect to the originally requested URL +// http.Redirect(rw, r, "/", http.StatusFound) +// return +// } - // If user is authenticated, go to the next handler - next(rw, r) -} +// // If user is authenticated, go to the next handler +// next(rw, r) +// } -func userIsAuthenticated(r *http.Request) bool { - // Example: Check for a session or cookie - session, err := r.Cookie("session_token") - if err != nil || session.Value == "" { - return false - } - // Validate the session_token if necessary - return true -} +// func userIsAuthenticated(r *http.Request) bool { +// // Example: Check for a session or cookie +// session, err := r.Cookie("session_token") +// if err != nil || session.Value == "" { +// return false +// } +// // Validate the session_token if necessary +// return true +// } -func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { - // Prepare the request body - data := url.Values{ - "client_id": {opts.ClientID}, - "client_secret": {opts.ClientSecret}, - "code": {code}, - "grant_type": {"authorization_code"}, - "redirect_uri": {requestURI}, - } - resp, err := http.PostForm(opts.TokenURL, data) - if err != nil { - return "", fmt.Errorf("failed to request token: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) - } - // Decode the response - var tokenResp struct { - AccessToken string `json:"access_token"` - } - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token response: %w", err) - } - return tokenResp.AccessToken, nil -} +// func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { +// // Prepare the request body +// data := url.Values{ +// "client_id": {opts.ClientID}, +// "client_secret": {opts.ClientSecret}, +// "code": {code}, +// "grant_type": {"authorization_code"}, +// "redirect_uri": {requestURI}, +// } +// resp, err := http.PostForm(opts.TokenURL, data) +// if err != nil { +// return "", fmt.Errorf("failed to request token: %w", err) +// } +// defer resp.Body.Close() +// if resp.StatusCode != http.StatusOK { +// return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status) +// } +// // Decode the response +// var tokenResp struct { +// AccessToken string `json:"access_token"` +// } +// if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { +// return "", fmt.Errorf("failed to decode token response: %w", err) +// } +// return tokenResp.AccessToken, nil +// } -func saveToken(rw ResponseWriter, token string) { - // Example: Save token in cookie - http.SetCookie(rw, &http.Cookie{ - Name: "auth_token", - Value: token, - // set other properties as necessary, such as Secure and HttpOnly - }) -} +// func saveToken(rw ResponseWriter, token string) { +// // Example: Save token in cookie +// http.SetCookie(rw, &http.Cookie{ +// Name: "auth_token", +// Value: token, +// // set other properties as necessary, such as Secure and HttpOnly +// }) +// } diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 12d674c..0bbf452 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -16,9 +16,9 @@ type realIP struct { type realIPOpts struct { // Header is the name of the header to use for the real client IP - Header string + Header string `json:"header"` // From is a list of Address / CIDRs to trust - From []*types.CIDR + From []*types.CIDR `json:"from"` /* If recursive search is disabled, the original client address that matches one of the trusted addresses is replaced by @@ -27,7 +27,7 @@ type realIPOpts struct { the original client address that matches one of the trusted addresses is replaced by the last non-trusted address sent in the request header field. */ - Recursive bool + Recursive bool `json:"recursive"` } var RealIP = &realIP{ diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index f47c272..67457b8 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -40,7 +40,7 @@ func TestSetRealIPOpts(t *testing.T) { } ri, err := NewRealIP(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) for i, CIDR := range ri.impl.(*realIP).From { @@ -61,15 +61,15 @@ func TestSetRealIP(t *testing.T) { "set_headers": map[string]string{testHeader: testRealIP}, } realip, err := NewRealIP(opts) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) mr, err := NewModifyRequest(optsMr) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) result, err := newMiddlewareTest(mid, nil) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) t.Log(traces) ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) diff --git a/internal/net/http/middleware/redirect_http_test.go b/internal/net/http/middleware/redirect_http_test.go index 6d2b2f6..b591fc2 100644 --- a/internal/net/http/middleware/redirect_http_test.go +++ b/internal/net/http/middleware/redirect_http_test.go @@ -12,7 +12,7 @@ func TestRedirectToHTTPs(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ scheme: "http", }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) } @@ -21,6 +21,6 @@ func TestNoRedirect(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ scheme: "https", }) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusOK) } diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index 8c169e2..2b444f3 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -6,7 +6,7 @@ import ( "time" gphttp "github.com/yusing/go-proxy/internal/net/http" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type Trace struct { @@ -88,7 +88,7 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace { return nil } return addTrace(&Trace{ - Time: U.FormatTime(time.Now()), + Time: strutils.FormatTime(time.Now()), Caller: m.Fullname(), Message: fmt.Sprintf(msg, args...), }) diff --git a/internal/net/http/modify_response_writer.go b/internal/net/http/modify_response_writer.go index 8ba0d72..a7a82c2 100644 --- a/internal/net/http/modify_response_writer.go +++ b/internal/net/http/modify_response_writer.go @@ -57,8 +57,7 @@ func (w *ModifyResponseWriter) WriteHeader(code int) { } if err := w.modifier(&resp); err != nil { - w.modifierErr = err - logger.Errorf("error modifying response: %s", err) + w.modifierErr = fmt.Errorf("response modifier error: %w", err) w.w.WriteHeader(http.StatusInternalServerError) return } diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index b3afc79..b4d4448 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -23,7 +23,7 @@ import ( "strings" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" "golang.org/x/net/http/httpguts" @@ -69,6 +69,8 @@ type ProxyRequest struct { // 1xx responses are forwarded to the client if the underlying // transport supports ClientTrace.Got1xxResponse. type ReverseProxy struct { + zerolog.Logger + // The transport used to perform proxy requests. // If nil, http.DefaultTransport is used. Transport http.RoundTripper @@ -149,7 +151,12 @@ func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) if transport == nil { panic("nil transport") } - rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target} + rp := &ReverseProxy{ + Logger: logger.With().Str("name", name).Logger(), + Transport: transport, + TargetName: name, + TargetURL: target, + } rp.ServeHTTP = rp.serveHTTP return rp } @@ -195,9 +202,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err switch { case errors.Is(err, context.Canceled), errors.Is(err, io.EOF): - logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) + logger.Debug().Err(err).Str("url", r.URL.String()).Msg("http proxy error") default: - logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) + logger.Err(err).Str("url", r.URL.String()).Msg("http proxy error") } if writeHeader { rw.WriteHeader(http.StatusBadGateway) @@ -219,6 +226,10 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response } func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { + if _, ok := rw.(DummyResponseWriter); ok { + return + } + transport := p.Transport ctx := req.Context() @@ -453,6 +464,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R resUpType := UpgradeType(res.Header) if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) + return } if !strings.EqualFold(reqUpType, resUpType) { p.errorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType), true) @@ -518,5 +530,3 @@ func IsPrint(s string) bool { } return true } - -var logger = logrus.WithField("module", "http") diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index 230ca16..65fc269 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -9,10 +9,15 @@ import ( type CIDR net.IPNet +var ( + ErrInvalidCIDR = E.New("invalid CIDR") + ErrInvalidCIDRType = E.New("invalid CIDR type") +) + func (cidr *CIDR) ConvertFrom(val any) E.Error { cidrStr, ok := val.(string) if !ok { - return E.TypeMismatch[string](val) + return ErrInvalidCIDRType.Subjectf("%T", val) } if !strings.Contains(cidrStr, "/") { @@ -20,7 +25,7 @@ func (cidr *CIDR) ConvertFrom(val any) E.Error { } _, ipnet, err := net.ParseCIDR(cidrStr) if err != nil { - return E.Invalid("CIDR", cidr) + return ErrInvalidCIDR.Subject(cidrStr) } *cidr = CIDR(*ipnet) return nil diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go index 871521f..2892a5f 100644 --- a/internal/net/types/stream.go +++ b/internal/net/types/stream.go @@ -5,9 +5,28 @@ import ( "net" ) -type Stream interface { - fmt.Stringer - net.Listener - Setup() error - Handle(conn net.Conn) error +type ( + Stream interface { + fmt.Stringer + StreamListener + Setup() error + Handle(conn StreamConn) error + } + StreamListener interface { + Addr() net.Addr + Accept() (StreamConn, error) + Close() error + } + StreamConn any + NetListenerWrapper struct { + net.Listener + } +) + +func NetListener(l net.Listener) StreamListener { + return NetListenerWrapper{Listener: l} +} + +func (l NetListenerWrapper) Accept() (StreamConn, error) { + return l.Listener.Accept() } diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 7d9d0e4..2b74f42 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -1,22 +1,32 @@ package notif import ( - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ( Dispatcher struct { task task.Task - logCh chan *logrus.Entry + logCh chan *LogMessage providers F.Set[Provider] } + LogMessage struct { + Level zerolog.Level + Title, Message string + } ) var dispatcher *Dispatcher +var ErrUnknownNotifProvider = E.New("unknown notification provider") + +const dispatchErr = "notification dispatch error" + func init() { dispatcher = newNotifDispatcher() go dispatcher.start() @@ -25,7 +35,7 @@ func init() { func newNotifDispatcher() *Dispatcher { return &Dispatcher{ task: task.GlobalTask("notif dispatcher"), - logCh: make(chan *logrus.Entry), + logCh: make(chan *LogMessage), providers: F.NewSet[Provider](), } } @@ -34,11 +44,13 @@ func GetDispatcher() *Dispatcher { return dispatcher } -func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E.Error) { +func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, error) { name := configSubTask.Name() createFunc, ok := Providers[name] if !ok { - return nil, E.NotExist("provider", name) + return nil, ErrUnknownNotifProvider. + Subject(name). + Withf(strutils.DoYouMean(utils.NearestField(name, Providers))) } if provider, err := createFunc(cfg); err != nil { return nil, err @@ -53,7 +65,6 @@ func RegisterProvider(configSubTask task.Task, cfg ProviderConfig) (Provider, E. func (disp *Dispatcher) start() { defer dispatcher.task.Finish("dispatcher stopped") - defer close(dispatcher.logCh) for { select { @@ -65,36 +76,39 @@ func (disp *Dispatcher) start() { } } -func (disp *Dispatcher) dispatch(entry *logrus.Entry) { +func (disp *Dispatcher) dispatch(msg *LogMessage) { task := disp.task.Subtask("dispatch notif") - defer task.Finish("notifs dispatched") + defer task.Finish("notif dispatched") - errs := E.NewBuilder("errors sending notif") + errs := E.NewBuilder(dispatchErr) disp.providers.RangeAllParallel(func(p Provider) { - if err := p.Send(task.Context(), entry); err != nil { - errs.Addf("%s: %s", p.Name(), err) + if err := p.Send(task.Context(), msg); err != nil { + errs.Add(E.PrependSubject(p.Name(), err)) } }) - if err := errs.Build(); err != nil { - logrus.Error("notif dispatcher failure: ", err) + if errs.HasError() { + E.LogError(errs.About(), errs.Error()) } } -// Levels implements logrus.Hook. -func (disp *Dispatcher) Levels() []logrus.Level { - return []logrus.Level{ - logrus.WarnLevel, - logrus.ErrorLevel, - logrus.FatalLevel, - logrus.PanicLevel, - } -} +// Run implements zerolog.Hook. +// func (disp *Dispatcher) Run(e *zerolog.Event, level zerolog.Level, message string) { +// if strings.HasPrefix(message, dispatchErr) { // prevent recursion +// return +// } +// switch level { +// case zerolog.WarnLevel, zerolog.ErrorLevel, zerolog.FatalLevel, zerolog.PanicLevel: +// disp.logCh <- &LogMessage{ +// Level: level, +// Message: message, +// } +// } +// } -// Fire implements logrus.Hook. -func (disp *Dispatcher) Fire(entry *logrus.Entry) error { - if disp.providers.Size() == 0 { - return nil +func Notify(title, msg string) { + dispatcher.logCh <- &LogMessage{ + Level: zerolog.InfoLevel, + Title: title, + Message: msg, } - disp.logCh <- entry - return nil } diff --git a/internal/notif/gotify.go b/internal/notif/gotify.go index ff6be73..90c9e01 100644 --- a/internal/notif/gotify.go +++ b/internal/notif/gotify.go @@ -9,7 +9,7 @@ import ( "net/url" "github.com/gotify/server/v2/model" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" U "github.com/yusing/go-proxy/internal/utils" ) @@ -39,7 +39,7 @@ func newGotifyClient(cfg map[string]any) (Provider, E.Error) { url, uErr := url.Parse(client.URL) if uErr != nil { - return nil, E.FailWith("parse url", uErr) + return nil, E.Errorf("invalid gotify URL %s", client.URL) } client.url = url @@ -52,30 +52,23 @@ func (client *GotifyClient) Name() string { } // Send implements NotifProvider. -func (client *GotifyClient) Send(ctx context.Context, entry *logrus.Entry) error { +func (client *GotifyClient) Send(ctx context.Context, logMsg *LogMessage) error { var priority int - var title string - switch entry.Level { - case logrus.WarnLevel: + switch logMsg.Level { + case zerolog.WarnLevel: priority = 2 - title = "Warning" - case logrus.ErrorLevel: + case zerolog.ErrorLevel: priority = 5 - title = "Error" - case logrus.FatalLevel, logrus.PanicLevel: + case zerolog.FatalLevel, zerolog.PanicLevel: priority = 8 - title = "Critical" default: return nil } - if subjects := FieldsAsTitle(entry); subjects != "" { - title = subjects + " " + title - } msg := &GotifyMessage{ - Title: title, - Message: entry.Message, + Title: logMsg.Title, + Message: logMsg.Message, Priority: priority, } diff --git a/internal/notif/logrus.go b/internal/notif/logrus.go deleted file mode 100644 index 6ca0bf1..0000000 --- a/internal/notif/logrus.go +++ /dev/null @@ -1,21 +0,0 @@ -package notif - -import ( - "fmt" - "strings" - - "github.com/sirupsen/logrus" - U "github.com/yusing/go-proxy/internal/utils" -) - -func FieldsAsTitle(entry *logrus.Entry) string { - if len(entry.Data) == 0 { - return "" - } - var parts []string - for k, v := range entry.Data { - parts = append(parts, fmt.Sprintf("%s: %s", k, v)) - } - parts[0] = U.Title(parts[0]) - return strings.Join(parts, ", ") -} diff --git a/internal/notif/providers.go b/internal/notif/providers.go index 33b3dc5..5ab6e9a 100644 --- a/internal/notif/providers.go +++ b/internal/notif/providers.go @@ -3,14 +3,13 @@ package notif import ( "context" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" ) type ( Provider interface { Name() string - Send(ctx context.Context, entry *logrus.Entry) error + Send(ctx context.Context, logMsg *LogMessage) error } ProviderCreateFunc func(map[string]any) (Provider, E.Error) ProviderConfig map[string]any diff --git a/internal/proxy/entry/entry.go b/internal/proxy/entry/entry.go index 0c17ae8..4eb71f3 100644 --- a/internal/proxy/entry/entry.go +++ b/internal/proxy/entry/entry.go @@ -23,18 +23,18 @@ func ValidateEntry(m *RawEntry) (Entry, E.Error) { scheme, err := T.NewScheme(m.Scheme) if err != nil { - return nil, err + return nil, E.From(err) } var entry Entry - e := E.NewBuilder("error validating entry") + errs := E.NewBuilder("entry validation failed") if scheme.IsStream() { - entry = validateStreamEntry(m, e) + entry = validateStreamEntry(m, errs) } else { - entry = validateRPEntry(m, scheme, e) + entry = validateRPEntry(m, scheme, errs) } - if err := e.Build(); err != nil { - return nil, err + if errs.HasError() { + return nil, errs.Error() } return entry, nil } diff --git a/internal/proxy/entry/raw.go b/internal/proxy/entry/raw.go index 6135618..1ebe78d 100644 --- a/internal/proxy/entry/raw.go +++ b/internal/proxy/entry/raw.go @@ -5,13 +5,14 @@ import ( "strings" "github.com/docker/docker/api/types" - "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/watcher/health" ) @@ -85,20 +86,20 @@ func (e *RawEntry) FillMissingFields() { } else if !isDocker { pp = "80" } else { - logrus.Debugf("no port found for %s", e.Alias) + logging.Debug().Msg("no port found for " + e.Alias) } } // replace private port with public port if using public IP. if e.Host == cont.PublicIP { if p, ok := cont.PrivatePortMapping[pp]; ok { - pp = U.PortString(p.PublicPort) + pp = strutils.PortString(p.PublicPort) } } // replace public port with private port if using private IP. if e.Host == cont.PrivateIP { if p, ok := cont.PublicPortMapping[pp]; ok { - pp = U.PortString(p.PrivatePort) + pp = strutils.PortString(p.PrivatePort) } } diff --git a/internal/proxy/entry/reverse_proxy.go b/internal/proxy/entry/reverse_proxy.go index 8d48652..976dde3 100644 --- a/internal/proxy/entry/reverse_proxy.go +++ b/internal/proxy/entry/reverse_proxy.go @@ -53,7 +53,7 @@ func (rp *ReverseProxyEntry) IdlewatcherConfig() *idlewatcher.Config { return rp.Idlewatcher } -func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEntry { +func validateRPEntry(m *RawEntry, s fields.Scheme, errs *E.Builder) *ReverseProxyEntry { cont := m.Container if cont == nil { cont = docker.DummyContainer @@ -64,35 +64,26 @@ func validateRPEntry(m *RawEntry, s fields.Scheme, b E.Builder) *ReverseProxyEnt lb = nil } - host, err := fields.ValidateHost(m.Host) - b.Add(err) + host := E.Collect(errs, fields.ValidateHost, m.Host) + port := E.Collect(errs, fields.ValidatePort, m.Port) + pathPats := E.Collect(errs, fields.ValidatePathPatterns, m.PathPatterns) + url := E.Collect(errs, url.Parse, fmt.Sprintf("%s://%s:%d", s, host, port)) + iwCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) - port, err := fields.ValidatePort(m.Port) - b.Add(err) - - pathPatterns, err := fields.ValidatePathPatterns(m.PathPatterns) - b.Add(err) - - url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) - b.Add(err) - - idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) - b.Add(err) - - if err != nil { + if errs.HasError() { return nil } return &ReverseProxyEntry{ Raw: m, - Alias: fields.NewAlias(m.Alias), + Alias: fields.Alias(m.Alias), Scheme: s, URL: net.NewURL(url), NoTLSVerify: m.NoTLSVerify, - PathPatterns: pathPatterns, + PathPatterns: pathPats, HealthCheck: m.HealthCheck, LoadBalance: lb, Middlewares: m.Middlewares, - Idlewatcher: idleWatcherCfg, + Idlewatcher: iwCfg, } } diff --git a/internal/proxy/entry/stream.go b/internal/proxy/entry/stream.go index 4416b20..5ffabe8 100644 --- a/internal/proxy/entry/stream.go +++ b/internal/proxy/entry/stream.go @@ -51,34 +51,25 @@ func (s *StreamEntry) IdlewatcherConfig() *idlewatcher.Config { return s.Idlewatcher } -func validateStreamEntry(m *RawEntry, b E.Builder) *StreamEntry { +func validateStreamEntry(m *RawEntry, errs *E.Builder) *StreamEntry { cont := m.Container if cont == nil { cont = docker.DummyContainer } - host, err := fields.ValidateHost(m.Host) - b.Add(err) + host := E.Collect(errs, fields.ValidateHost, m.Host) + port := E.Collect(errs, fields.ValidateStreamPort, m.Port) + scheme := E.Collect(errs, fields.ValidateStreamScheme, m.Scheme) + url := E.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", scheme.ListeningScheme, host, port.ListeningPort)) + idleWatcherCfg := E.Collect(errs, idlewatcher.ValidateConfig, m.Container) - port, err := fields.ValidateStreamPort(m.Port) - b.Add(err) - - scheme, err := fields.ValidateStreamScheme(m.Scheme) - b.Add(err) - - url, err := E.Check(net.ParseURL(fmt.Sprintf("%s://%s:%d", scheme.ProxyScheme, m.Host, port.ProxyPort))) - b.Add(err) - - idleWatcherCfg, err := idlewatcher.ValidateConfig(m.Container) - b.Add(err) - - if b.HasError() { + if errs.HasError() { return nil } return &StreamEntry{ Raw: m, - Alias: fields.NewAlias(m.Alias), + Alias: fields.Alias(m.Alias), Scheme: *scheme, URL: url, Host: host, diff --git a/internal/proxy/fields/alias.go b/internal/proxy/fields/alias.go index 289f964..07a91eb 100644 --- a/internal/proxy/fields/alias.go +++ b/internal/proxy/fields/alias.go @@ -1,6 +1,3 @@ package fields -type ( - Alias string - NewAlias = Alias -) +type Alias string diff --git a/internal/proxy/fields/host.go b/internal/proxy/fields/host.go index 68c17c2..892e72c 100644 --- a/internal/proxy/fields/host.go +++ b/internal/proxy/fields/host.go @@ -1,14 +1,10 @@ package fields -import ( - E "github.com/yusing/go-proxy/internal/error" -) - type ( Host string Subdomain = Alias ) -func ValidateHost[String ~string](s String) (Host, E.Error) { +func ValidateHost[String ~string](s String) (Host, error) { return Host(s), nil } diff --git a/internal/proxy/fields/path_pattern.go b/internal/proxy/fields/path_pattern.go index 8d9abd3..29b7cf4 100644 --- a/internal/proxy/fields/path_pattern.go +++ b/internal/proxy/fields/path_pattern.go @@ -1,6 +1,8 @@ package fields import ( + "errors" + "fmt" "regexp" E "github.com/yusing/go-proxy/internal/error" @@ -13,12 +15,17 @@ type ( var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) -func ValidatePathPattern(s string) (PathPattern, E.Error) { +var ( + ErrEmptyPathPattern = errors.New("path must not be empty") + ErrInvalidPathPattern = errors.New("invalid path pattern") +) + +func ValidatePathPattern(s string) (PathPattern, error) { if len(s) == 0 { - return "", E.Invalid("path", "must not be empty") + return "", ErrEmptyPathPattern } if !pathPattern.MatchString(s) { - return "", E.Invalid("path pattern", s) + return "", fmt.Errorf("%w %q", ErrInvalidPathPattern, s) } return PathPattern(s), nil } @@ -27,13 +34,15 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.Error) { if len(s) == 0 { return []PathPattern{"/"}, nil } + errs := E.NewBuilder("invalid path patterns") pp := make(PathPatterns, len(s)) for i, v := range s { pattern, err := ValidatePathPattern(v) if err != nil { - return nil, err + errs.Add(err) + } else { + pp[i] = pattern } - pp[i] = pattern } - return pp, nil + return pp, errs.Error() } diff --git a/internal/proxy/fields/path_pattern_test.go b/internal/proxy/fields/path_pattern_test.go index d19cb97..261ee3f 100644 --- a/internal/proxy/fields/path_pattern_test.go +++ b/internal/proxy/fields/path_pattern_test.go @@ -1,9 +1,9 @@ package fields import ( + "errors" "testing" - E "github.com/yusing/go-proxy/internal/error" U "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -38,10 +38,10 @@ var invalidPatterns = []string{ func TestPathPatternRegex(t *testing.T) { for _, pattern := range validPatterns { _, err := ValidatePathPattern(pattern) - U.ExpectNoError(t, err.Error()) + U.ExpectNoError(t, err) } for _, pattern := range invalidPatterns { _, err := ValidatePathPattern(pattern) - U.ExpectError2(t, pattern, E.ErrInvalid, err.Error()) + U.ExpectTrue(t, errors.Is(err, ErrInvalidPathPattern)) } } diff --git a/internal/proxy/fields/port.go b/internal/proxy/fields/port.go index 5780005..9b809fd 100644 --- a/internal/proxy/fields/port.go +++ b/internal/proxy/fields/port.go @@ -4,22 +4,25 @@ import ( "strconv" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type Port int -func ValidatePort[String ~string](v String) (Port, E.Error) { - p, err := strconv.Atoi(string(v)) +var ErrPortOutOfRange = E.New("port out of range") + +func ValidatePort[String ~string](v String) (Port, error) { + p, err := strutils.Atoi(string(v)) if err != nil { - return ErrPort, E.Invalid("port number", v).With(err) + return ErrPort, err } return ValidatePortInt(p) } -func ValidatePortInt[Int int | uint16](v Int) (Port, E.Error) { +func ValidatePortInt[Int int | uint16](v Int) (Port, error) { p := Port(v) if !p.inBound() { - return ErrPort, E.OutOfRange("port", p) + return ErrPort, ErrPortOutOfRange.Subject(strconv.Itoa(int(p))) } return p, nil } diff --git a/internal/proxy/fields/scheme.go b/internal/proxy/fields/scheme.go index 2e4f6e5..b006464 100644 --- a/internal/proxy/fields/scheme.go +++ b/internal/proxy/fields/scheme.go @@ -6,12 +6,14 @@ import ( type Scheme string -func NewScheme[String ~string](s String) (Scheme, E.Error) { +var ErrInvalidScheme = E.New("invalid scheme") + +func NewScheme(s string) (Scheme, error) { switch s { case "http", "https", "tcp", "udp": return Scheme(s), nil } - return "", E.Invalid("scheme", s) + return "", ErrInvalidScheme.Subject(s) } func (s Scheme) IsHTTP() bool { return s == "http" } diff --git a/internal/proxy/fields/stream_port.go b/internal/proxy/fields/stream_port.go index 020d455..6cbc767 100644 --- a/internal/proxy/fields/stream_port.go +++ b/internal/proxy/fields/stream_port.go @@ -3,7 +3,6 @@ package fields import ( "strings" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" ) @@ -12,7 +11,9 @@ type StreamPort struct { ProxyPort Port `json:"proxy"` } -func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { +var ErrStreamPortTooManyColons = E.New("too many colons") + +func ValidateStreamPort(p string) (StreamPort, error) { split := strings.Split(p, ":") switch len(split) { @@ -21,36 +22,14 @@ func ValidateStreamPort(p string) (_ StreamPort, err E.Error) { case 2: break default: - err = E.Invalid("stream port", p).With("too many colons") - return + return StreamPort{}, ErrStreamPortTooManyColons.Subject(p) } - listeningPort, err := ValidatePort(split[0]) - if err != nil { - err = err.Subject("listening port") - return - } - - proxyPort, err := ValidatePort(split[1]) - - if err.Is(E.ErrOutOfRange) { - err = err.Subject("proxy port") - return - } else if err != nil { - proxyPort, err = parseNameToPort(split[1]) - if err != nil { - err = E.Invalid("proxy port", proxyPort) - return - } + listeningPort, lErr := ValidatePort(split[0]) + proxyPort, pErr := ValidatePort(split[1]) + if err := E.Join(lErr, pErr); err != nil { + return StreamPort{}, err } return StreamPort{listeningPort, proxyPort}, nil } - -func parseNameToPort(name string) (Port, E.Error) { - port, ok := common.ServiceNamePortMapTCP[name] - if !ok { - return ErrPort, E.Invalid("service", name) - } - return Port(port), nil -} diff --git a/internal/proxy/fields/stream_port_test.go b/internal/proxy/fields/stream_port_test.go index fce707c..a18732a 100644 --- a/internal/proxy/fields/stream_port_test.go +++ b/internal/proxy/fields/stream_port_test.go @@ -1,9 +1,9 @@ package fields import ( + "strconv" "testing" - E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -11,7 +11,6 @@ var validPorts = []string{ "1234:5678", "0:2345", "2345", - "1234:postgres", } var invalidPorts = []string{ @@ -19,7 +18,6 @@ var invalidPorts = []string{ "123:", "0:", ":1234", - "1234:1234:1234", "qwerty", "asdfgh:asdfgh", "1234:asdfgh", @@ -32,17 +30,25 @@ var outOfRangePorts = []string{ "0:65536", } +var tooManyColonsPorts = []string{ + "1234:1234:1234", +} + func TestStreamPort(t *testing.T) { for _, port := range validPorts { _, err := ValidateStreamPort(port) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) } for _, port := range invalidPorts { _, err := ValidateStreamPort(port) - ExpectError2(t, port, E.ErrInvalid, err.Error()) + ExpectError2(t, port, strconv.ErrSyntax, err) } for _, port := range outOfRangePorts { _, err := ValidateStreamPort(port) - ExpectError2(t, port, E.ErrOutOfRange, err.Error()) + ExpectError2(t, port, ErrPortOutOfRange, err) + } + for _, port := range tooManyColonsPorts { + _, err := ValidateStreamPort(port) + ExpectError2(t, port, ErrStreamPortTooManyColons, err) } } diff --git a/internal/proxy/fields/stream_scheme.go b/internal/proxy/fields/stream_scheme.go index d195a29..0c0180e 100644 --- a/internal/proxy/fields/stream_scheme.go +++ b/internal/proxy/fields/stream_scheme.go @@ -12,22 +12,23 @@ type StreamScheme struct { ProxyScheme Scheme `json:"proxy"` } -func ValidateStreamScheme(s string) (ss *StreamScheme, err E.Error) { - ss = &StreamScheme{} +func ValidateStreamScheme(s string) (*StreamScheme, error) { + ss := &StreamScheme{} parts := strings.Split(s, ":") if len(parts) == 1 { parts = []string{s, s} } else if len(parts) != 2 { - return nil, E.Invalid("stream scheme", s) + return nil, ErrInvalidScheme.Subject(s) } - ss.ListeningScheme, err = NewScheme(parts[0]) - if err.HasError() { - return nil, err - } - ss.ProxyScheme, err = NewScheme(parts[1]) - if err.HasError() { + + var lErr, pErr error + ss.ListeningScheme, lErr = NewScheme(parts[0]) + ss.ProxyScheme, pErr = NewScheme(parts[1]) + + if err := E.Join(lErr, pErr); err != nil { return nil, err } + return ss, nil } diff --git a/internal/proxy/fields/stream_scheme_test.go b/internal/proxy/fields/stream_scheme_test.go new file mode 100644 index 0000000..15227c0 --- /dev/null +++ b/internal/proxy/fields/stream_scheme_test.go @@ -0,0 +1,37 @@ +package fields + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +var ( + validStreamSchemes = []string{ + "tcp:tcp", + "tcp:udp", + "udp:tcp", + "udp:udp", + "tcp", + "udp", + } + + invalidStreamSchemes = []string{ + "tcp:tcp:", + "tcp:", + ":udp:", + ":udp", + "top", + } +) + +func TestNewStreamScheme(t *testing.T) { + for _, s := range validStreamSchemes { + _, err := ValidateStreamScheme(s) + ExpectNoError(t, err) + } + for _, s := range invalidStreamSchemes { + _, err := ValidateStreamScheme(s) + ExpectError(t, ErrInvalidScheme, err) + } +} diff --git a/internal/route/http.go b/internal/route/http.go index 0010e3c..f2b40a7 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -7,13 +7,13 @@ import ( "strings" "sync" - "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/errorpage" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/proxy/entry" PT "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/task" @@ -33,6 +33,8 @@ type ( rp *gphttp.ReverseProxy task task.Task + + l zerolog.Logger } SubdomainKey = PT.Alias @@ -88,6 +90,10 @@ func NewHTTPRoute(entry *entry.ReverseProxyEntry) (impl, E.Error) { ReverseProxyEntry: entry, rp: rp, task: task.DummyTask(), + l: logger.With(). + Str("type", string(entry.Scheme)). + Str("name", string(entry.Alias)). + Logger(), } return r, nil } @@ -107,11 +113,11 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { defer httpRoutesMu.Unlock() if !entry.UseHealthCheck(r) && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { - logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) + r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled") if r.HealthCheck == nil { r.HealthCheck = new(health.HealthCheckConfig) } - r.HealthCheck.Disable = true + r.HealthCheck.Disable = false } switch { @@ -143,7 +149,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { if r.HealthMon != nil { if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { - logrus.Warn(E.FailWith("health monitor", err)) + E.LogWarn("health monitor error", err, &r.l) } } @@ -209,15 +215,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if err != nil { if !middleware.ServeStaticErrorPageFile(w, r) { - logrus.Error(E.Failure("request"). - Subjectf("%s %s", r.Method, r.URL.String()). - With(err)) + logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "text/html; charset=utf-8") if _, err := w.Write(errorPage); err != nil { - logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err) + logger.Err(err).Msg("failed to write error page") } } else { http.Error(w, err.Error(), http.StatusNotFound) diff --git a/internal/route/logger.go b/internal/route/logger.go new file mode 100644 index 0000000..caf8f51 --- /dev/null +++ b/internal/route/logger.go @@ -0,0 +1,5 @@ +package route + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "route").Logger() diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 440e47b..c1dca89 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -6,77 +6,90 @@ import ( "strings" "github.com/docker/docker/client" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" - D "github.com/yusing/go-proxy/internal/docker" + "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/proxy/entry" - R "github.com/yusing/go-proxy/internal/route" - W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/utils/strutils" + "github.com/yusing/go-proxy/internal/watcher" ) type DockerProvider struct { name, dockerHost string ExplicitOnly bool + l zerolog.Logger } var ( AliasRefRegex = regexp.MustCompile(`#\d+`) AliasRefRegexOld = regexp.MustCompile(`\$\d+`) + + ErrAliasRefIndexOutOfRange = E.New("index out of range") ) -func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.Error) { +func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, error) { if dockerHost == common.DockerHostFromEnv { dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost) } - return &DockerProvider{name, dockerHost, explicitOnly}, nil + return &DockerProvider{ + name, + dockerHost, + explicitOnly, + logger.With().Str("type", "docker").Str("name", name).Logger(), + }, nil } func (p *DockerProvider) String() string { return "docker@" + p.name } -func (p *DockerProvider) NewWatcher() W.Watcher { - return W.NewDockerWatcher(p.dockerHost) +func (p *DockerProvider) Logger() *zerolog.Logger { + return &p.l } -func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.Error) { - routes = R.NewRoutes() +func (p *DockerProvider) NewWatcher() watcher.Watcher { + return watcher.NewDockerWatcher(p.dockerHost) +} + +func (p *DockerProvider) loadRoutesImpl() (route.Routes, E.Error) { + routes := route.NewRoutes() entries := entry.NewProxyEntries() - containers, err := D.ListContainers(p.dockerHost) + containers, err := docker.ListContainers(p.dockerHost) if err != nil { - return routes, err + return routes, E.From(err) } - errors := E.NewBuilder("errors in docker labels") + errs := E.NewBuilder("") for _, c := range containers { - container := D.FromDocker(&c, p.dockerHost) + container := docker.FromDocker(&c, p.dockerHost) if container.IsExcluded { continue } newEntries, err := p.entriesFromContainerLabels(container) if err != nil { - errors.Add(err) + errs.Add(err.Subject(container.ContainerName)) } // although err is not nil // there may be some valid entries in `en` dups := entries.MergeFrom(newEntries) // add the duplicate proxy entries to the error dups.RangeAll(func(k string, v *entry.RawEntry) { - errors.Addf("duplicate alias %s", k) + errs.Addf("duplicated alias %s", k) }) } - routes, err = R.FromEntries(entries) - errors.Add(err) + routes, err = route.FromEntries(entries) + errs.Add(err) - return routes, errors.Build() + return routes, errs.Error() } -func (p *DockerProvider) shouldIgnore(container *D.Container) bool { +func (p *DockerProvider) shouldIgnore(container *docker.Container) bool { return container.IsExcluded || !container.IsExplicit && p.ExplicitOnly || !container.IsExplicit && container.IsDatabase || @@ -85,7 +98,7 @@ func (p *DockerProvider) shouldIgnore(container *D.Container) bool { // Returns a list of proxy entries for a container. // Always non-nil. -func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries entry.RawEntries, _ E.Error) { +func (p *DockerProvider) entriesFromContainerLabels(container *docker.Container) (entries entry.RawEntries, _ E.Error) { entries = entry.NewProxyEntries() if p.shouldIgnore(container) { @@ -100,9 +113,9 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent }) } - errors := E.NewBuilder("failed to apply label") + errs := E.NewBuilder("label errors") for key, val := range container.Labels { - errors.Add(p.applyLabel(container, entries, key, val)) + errs.Add(p.applyLabel(container, entries, key, val)) } // remove all entries that failed to fill in missing fields @@ -110,59 +123,56 @@ func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (ent re.FillMissingFields() }) - return entries, errors.Build().Subject(container.ContainerName) + return entries, errs.Error() } -func (p *DockerProvider) applyLabel(container *D.Container, entries entry.RawEntries, key, val string) (res E.Error) { - b := E.NewBuilder("errors in label %s", key) - defer b.To(&res) +func (p *DockerProvider) applyLabel(container *docker.Container, entries entry.RawEntries, key, val string) E.Error { + lbl := docker.ParseLabel(key, val) + if lbl.Namespace != docker.NSProxy { + return nil + } + if lbl.Target == docker.WildcardAlias { + // apply label for all aliases + labelErrs := entries.CollectErrors(func(a string, e *entry.RawEntry) error { + return docker.ApplyLabel(e, lbl) + }) + if err := E.Join(labelErrs...); err != nil { + return err.Subject(lbl.Target) + } + return nil + } - refErr := E.NewBuilder("errors in alias references") + refErrs := E.NewBuilder("alias ref errors") replaceIndexRef := func(ref string) string { - index, err := strconv.Atoi(ref[1:]) + index, err := strutils.Atoi(ref[1:]) if err != nil { - refErr.Add(E.Invalid("integer", ref)) + refErrs.Add(err) return ref } if index < 1 || index > len(container.Aliases) { - refErr.Add(E.OutOfRange("index", ref)) + refErrs.Add(ErrAliasRefIndexOutOfRange.Subject(strconv.Itoa(index))) return ref } return container.Aliases[index-1] } - lbl, err := D.ParseLabel(key, val) - if err != nil { - b.Add(err.Subject(key)) + lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef) + lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(ref string) string { + p.l.Warn().Msgf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#")) + return replaceIndexRef(ref) + }) + if refErrs.HasError() { + return refErrs.Error().Subject(lbl.String()) } - if lbl.Namespace != D.NSProxy { - return + + en, ok := entries.Load(lbl.Target) + if !ok { + en = &entry.RawEntry{ + Alias: lbl.Target, + Container: container, + } + entries.Store(lbl.Target, en) } - if lbl.Target == D.WildcardAlias { - // apply label for all aliases - entries.RangeAll(func(a string, e *entry.RawEntry) { - if err = D.ApplyLabel(e, lbl); err != nil { - b.Add(err) - } - }) - } else { - lbl.Target = AliasRefRegex.ReplaceAllStringFunc(lbl.Target, replaceIndexRef) - lbl.Target = AliasRefRegexOld.ReplaceAllStringFunc(lbl.Target, func(s string) string { - logrus.Warnf("%q should now be %q, old syntax will be removed in a future version", lbl, strings.ReplaceAll(lbl.String(), "$", "#")) - return replaceIndexRef(s) - }) - if refErr.HasError() { - b.Add(refErr.Build()) - return - } - config, ok := entries.Load(lbl.Target) - if !ok { - b.Add(E.NotExist("alias", lbl.Target)) - return - } - if err = D.ApplyLabel(config, lbl); err != nil { - b.Add(err) - } - } - return + + return docker.ApplyLabel(en, lbl) } diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index 43d3cc0..dfd0558 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -1,8 +1,8 @@ package provider import ( - "strings" "testing" + "time" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/network" @@ -20,7 +20,7 @@ var ( p DockerProvider ) -func TestApplyLabelWildcard(t *testing.T) { +func TestApplyLabel(t *testing.T) { pathPatterns := ` - / - POST /upload/{$} @@ -61,9 +61,13 @@ func TestApplyLabelWildcard(t *testing.T) { "proxy.a.middlewares.middleware1.prop2": "value2", "proxy.a.middlewares.middleware2.prop3": "value3", "proxy.a.middlewares.middleware2.prop4": "value4", + "proxy.a.homepage.show": "true", + "proxy.a.homepage.icon": "png/example.png", + "proxy.a.healthcheck.path": "/ping", + "proxy.a.healthcheck.interval": "10s", }, }, client.DefaultDockerHost)) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) a, ok := entries.Load("a") ExpectTrue(t, ok) @@ -102,6 +106,12 @@ func TestApplyLabelWildcard(t *testing.T) { ExpectEqual(t, a.Container.StopSignal, "SIGTERM") ExpectEqual(t, b.Container.StopSignal, "SIGTERM") + + ExpectEqual(t, a.Homepage.Show, true) + ExpectEqual(t, a.Homepage.Icon, "png/example.png") + + ExpectEqual(t, a.HealthCheck.Path, "/ping") + ExpectEqual(t, a.HealthCheck.Interval, 10*time.Second) } func TestApplyLabelWithAlias(t *testing.T) { @@ -123,7 +133,7 @@ func TestApplyLabelWithAlias(t *testing.T) { c, ok := entries.Load("c") ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, a.Scheme, "http") ExpectEqual(t, a.Port, "3333") ExpectEqual(t, a.NoTLSVerify, true) @@ -133,7 +143,7 @@ func TestApplyLabelWithAlias(t *testing.T) { } func TestApplyLabelWithRef(t *testing.T) { - entries := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, State: "running", Labels: map[string]string{ @@ -171,8 +181,7 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { }, }, "") _, err := p.entriesFromContainerLabels(c) - ExpectError(t, E.ErrOutOfRange, err.Error()) - ExpectTrue(t, strings.Contains(err.String(), "index out of range")) + ExpectError(t, ErrAliasRefIndexOutOfRange, err) _, err = p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, @@ -182,13 +191,33 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { "proxy.#0.host": "localhost", }, }, "")) - ExpectError(t, E.ErrOutOfRange, err.Error()) - ExpectTrue(t, strings.Contains(err.String(), "index out of range")) + ExpectError(t, ErrAliasRefIndexOutOfRange, err) +} + +func TestDynamicAliases(t *testing.T) { + c := D.FromDocker(&types.Container{ + Names: []string{"app1"}, + State: "running", + Labels: map[string]string{ + "proxy.app1.port": "1234", + "proxy.app1_backend.port": "5678", + }, + }, client.DefaultDockerHost) + + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("app1") + ExpectTrue(t, ok) + ExpectEqual(t, raw.Scheme, "http") + ExpectEqual(t, raw.Port, "1234") + + raw, ok = E.Must(p.entriesFromContainerLabels(c)).Load("app1_backend") + ExpectTrue(t, ok) + ExpectEqual(t, raw.Scheme, "http") + ExpectEqual(t, raw.Port, "5678") } func TestPublicIPLocalhost(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "127.0.0.1") ExpectEqual(t, raw.Host, raw.Container.PublicIP) @@ -196,7 +225,7 @@ func TestPublicIPLocalhost(t *testing.T) { func TestPublicIPRemote(t *testing.T) { c := D.FromDocker(&types.Container{Names: dummyNames, State: "running"}, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") ExpectEqual(t, raw.Host, raw.Container.PublicIP) @@ -213,7 +242,7 @@ func TestPrivateIPLocalhost(t *testing.T) { }, }, }, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PrivateIP, "172.17.0.123") ExpectEqual(t, raw.Host, raw.Container.PrivateIP) @@ -231,7 +260,7 @@ func TestPrivateIPRemote(t *testing.T) { }, }, }, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) ExpectEqual(t, raw.Container.PrivateIP, "") ExpectEqual(t, raw.Container.PublicIP, "1.2.3.4") @@ -260,9 +289,9 @@ func TestStreamDefaultValues(t *testing.T) { t.Run("local", func(t *testing.T) { c := D.FromDocker(cont, client.DefaultDockerHost) - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - en := Must(entry.ValidateEntry(raw)) + en := E.Must(entry.ValidateEntry(raw)) a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) @@ -273,9 +302,9 @@ func TestStreamDefaultValues(t *testing.T) { t.Run("remote", func(t *testing.T) { c := D.FromDocker(cont, "tcp://1.2.3.4:2375") - raw, ok := Must(p.entriesFromContainerLabels(c)).Load("a") + raw, ok := E.Must(p.entriesFromContainerLabels(c)).Load("a") ExpectTrue(t, ok) - en := Must(entry.ValidateEntry(raw)) + en := E.Must(entry.ValidateEntry(raw)) a := ExpectType[*entry.StreamEntry](t, en) ExpectEqual(t, a.Scheme.ListeningScheme, T.Scheme("udp")) ExpectEqual(t, a.Scheme.ProxyScheme, T.Scheme("udp")) @@ -286,7 +315,7 @@ func TestStreamDefaultValues(t *testing.T) { } func TestExplicitExclude(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ D.LabelAliases: "a", @@ -299,7 +328,7 @@ func TestExplicitExclude(t *testing.T) { func TestImplicitExcludeDatabase(t *testing.T) { t.Run("mount path detection", func(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Mounts: []types.MountPoint{ {Source: "/data", Destination: "/var/lib/postgresql/data"}, @@ -308,7 +337,7 @@ func TestImplicitExcludeDatabase(t *testing.T) { ExpectFalse(t, ok) }) t.Run("exposed port detection", func(t *testing.T) { - _, ok := Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ + _, ok := E.Must(p.entriesFromContainerLabels(D.FromDocker(&types.Container{ Names: dummyNames, Ports: []types.Port{ {Type: "tcp", PrivatePort: 5432, PublicPort: 5432}, @@ -317,58 +346,3 @@ func TestImplicitExcludeDatabase(t *testing.T) { ExpectFalse(t, ok) }) } - -// func TestImplicitExcludeNoExposedPort(t *testing.T) { -// var p DockerProvider -// entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// State: "running", -// }, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectFalse(t, ok) -// } - -// func TestNotExcludeSpecifiedPort(t *testing.T) { -// var p DockerProvider -// entries, err := p.entriesFromContainerLabels(D.FromDocker(&types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// Labels: map[string]string{ -// "proxy.redis.port": "6379:6379", // but specified in label -// }, -// }, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectTrue(t, ok) -// } - -// func TestNotExcludeNonExposedPortHostNetwork(t *testing.T) { -// var p DockerProvider -// cont := &types.Container{ -// Image: "redis", -// Names: []string{"redis"}, -// Ports: []types.Port{ -// {Type: "tcp", PrivatePort: 6379, PublicPort: 0}, // not exposed -// }, -// Labels: map[string]string{ -// "proxy.redis.port": "6379:6379", -// }, -// } -// cont.HostConfig.NetworkMode = "host" - -// entries, err := p.entriesFromContainerLabels(D.FromDocker(cont, "")) -// ExpectNoError(t, err.Error()) - -// _, ok := entries.Load("redis") -// ExpectTrue(t, ok) -// } diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 0118cd2..6c990ce 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -12,25 +12,27 @@ import ( type EventHandler struct { provider *Provider - added []string - removed []string - paused []string - updated []string - errs E.Builder + errs *E.Builder + added *E.Builder + removed *E.Builder + updated *E.Builder } func (provider *Provider) newEventHandler() *EventHandler { return &EventHandler{ provider: provider, errs: E.NewBuilder("event errors"), + added: E.NewBuilder("added"), + removed: E.NewBuilder("removed"), + updated: E.NewBuilder("updated"), } } func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { oldRoutes := handler.provider.routes - newRoutes, err := handler.provider.LoadRoutesImpl() + newRoutes, err := handler.provider.loadRoutesImpl() if err != nil { - handler.errs.Add(err.Subject("load routes")) + handler.errs.Add(err) if newRoutes.Size() == 0 { return } @@ -41,17 +43,19 @@ func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { for _, event := range events { eventsLog.Addf("event %s, actor: name=%s, id=%s", event.Action, event.ActorName, event.ActorID) } - handler.provider.l.Debug(eventsLog.String()) + E.LogDebug(eventsLog.About(), eventsLog.Error(), handler.provider.Logger()) + oldRoutesLog := E.NewBuilder("old routes") - oldRoutes.RangeAll(func(k string, r *route.Route) { - oldRoutesLog.Addf(k) + oldRoutes.RangeAllParallel(func(k string, r *route.Route) { + oldRoutesLog.Adds(k) }) - handler.provider.l.Debug(oldRoutesLog.String()) + E.LogDebug(oldRoutesLog.About(), oldRoutesLog.Error(), handler.provider.Logger()) + newRoutesLog := E.NewBuilder("new routes") - newRoutes.RangeAll(func(k string, r *route.Route) { - newRoutesLog.Addf(k) + newRoutes.RangeAllParallel(func(k string, r *route.Route) { + newRoutesLog.Adds(k) }) - handler.provider.l.Debug(newRoutesLog.String()) + E.LogDebug(newRoutesLog.About(), newRoutesLog.Error(), handler.provider.Logger()) } oldRoutes.RangeAll(func(k string, oldr *route.Route) { @@ -95,41 +99,35 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool func (handler *EventHandler) Add(parent task.Task, route *route.Route) { err := handler.provider.startRoute(parent, route) if err != nil { - handler.errs.Add(E.FailWith("add "+route.Entry.Alias, err)) + handler.errs.Add(err.Subject("add")) } else { - handler.added = append(handler.added, route.Entry.Alias) + handler.added.Adds(route.Entry.Alias) } } func (handler *EventHandler) Remove(route *route.Route) { route.Finish("route removed") handler.provider.routes.Delete(route.Entry.Alias) - handler.removed = append(handler.removed, route.Entry.Alias) + handler.removed.Adds(route.Entry.Alias) } func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) { oldRoute.Finish("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { - handler.errs.Add(E.FailWith("update "+newRoute.Entry.Alias, err)) + handler.errs.Add(err.Subject("update")) } else { - handler.updated = append(handler.updated, newRoute.Entry.Alias) + handler.updated.Adds(newRoute.Entry.Alias) } } func (handler *EventHandler) Log() { results := E.NewBuilder("event occured") - for _, alias := range handler.added { - results.Addf("added %s", alias) - } - for _, alias := range handler.removed { - results.Addf("removed %s", alias) - } - for _, alias := range handler.updated { - results.Addf("updated %s", alias) - } - results.Add(handler.errs.Build()) - if result := results.Build(); result != nil { - handler.provider.l.Info(result) + results.Add(handler.added.Error()) + results.Add(handler.removed.Error()) + results.Add(handler.updated.Error()) + results.Add(handler.errs.Error()) + if result := results.String(); result != "" { + handler.provider.Logger().Info().Msg(result) } } diff --git a/internal/route/provider/file.go b/internal/route/provider/file.go index 8f302e0..754b680 100644 --- a/internal/route/provider/file.go +++ b/internal/route/provider/file.go @@ -1,11 +1,10 @@ package provider import ( - "errors" "os" "path" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/proxy/entry" @@ -17,53 +16,49 @@ import ( type FileProvider struct { fileName string path string + l zerolog.Logger } -func FileProviderImpl(filename string) (ProviderImpl, E.Error) { +func FileProviderImpl(filename string) (ProviderImpl, error) { impl := &FileProvider{ fileName: filename, path: path.Join(common.ConfigBasePath, filename), + l: logger.With().Str("type", "file").Str("name", filename).Logger(), } _, err := os.Stat(impl.path) - switch { - case err == nil: - return impl, nil - case errors.Is(err, os.ErrNotExist): - return nil, E.NotExist("file", impl.path) - default: - return nil, E.UnexpectedError(err) + if err != nil { + return nil, err } + return impl, nil } func Validate(data []byte) E.Error { return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data) } -func (p FileProvider) String() string { +func (p *FileProvider) String() string { return p.fileName } -func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.Error) { - routes = R.NewRoutes() - - b := E.NewBuilder("validation failure") - defer b.To(&res) +func (p *FileProvider) Logger() *zerolog.Logger { + return &p.l +} +func (p *FileProvider) loadRoutesImpl() (R.Routes, E.Error) { + routes := R.NewRoutes() entries := entry.NewProxyEntries() - data, err := E.Check(os.ReadFile(p.path)) + data, err := os.ReadFile(p.path) if err != nil { - b.Add(E.FailWith("read file", err)) - return + return routes, E.From(err) } - if err = entries.UnmarshalFromYAML(data); err != nil { - b.Add(err) - return + if err := entries.UnmarshalFromYAML(data); err != nil { + return routes, E.From(err) } if err := Validate(data); err != nil { - logrus.Warn(err) + E.LogWarn(p.fileName+": validation failure", err) } return R.FromEntries(entries) diff --git a/internal/route/provider/logger.go b/internal/route/provider/logger.go new file mode 100644 index 0000000..d734627 --- /dev/null +++ b/internal/route/provider/logger.go @@ -0,0 +1,5 @@ +package provider + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "provider").Logger() diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 195d910..e6cb01e 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -1,11 +1,12 @@ package provider import ( + "errors" "fmt" "path" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/task" @@ -22,13 +23,12 @@ type ( routes R.Routes watcher W.Watcher - - l *logrus.Entry } ProviderImpl interface { fmt.Stringer + loadRoutesImpl() (R.Routes, E.Error) NewWatcher() W.Watcher - LoadRoutesImpl() (R.Routes, E.Error) + Logger() *zerolog.Logger } ProviderType string ProviderStats struct { @@ -45,20 +45,22 @@ const ( providerEventFlushInterval = 300 * time.Millisecond ) +var ( + ErrEmptyProviderName = errors.New("empty provider name") +) + func newProvider(name string, t ProviderType) *Provider { - p := &Provider{ + return &Provider{ name: name, t: t, routes: R.NewRoutes(), } - p.l = logrus.WithField("provider", p) - return p } -func NewFileProvider(filename string) (p *Provider, err E.Error) { +func NewFileProvider(filename string) (p *Provider, err error) { name := path.Base(filename) if name == "" { - return nil, E.Invalid("file name", "empty") + return nil, ErrEmptyProviderName } p = newProvider(name, ProviderTypeFile) p.ProviderImpl, err = FileProviderImpl(filename) @@ -69,9 +71,9 @@ func NewFileProvider(filename string) (p *Provider, err E.Error) { return } -func NewDockerProvider(name string, dockerHost string) (p *Provider, err E.Error) { +func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) { if name == "" { - return nil, E.Invalid("provider name", "empty") + return nil, ErrEmptyProviderName } p = newProvider(name, ProviderTypeDocker) @@ -106,7 +108,7 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { if err != nil { p.routes.Delete(r.Entry.Alias) subtask.Finish(err) // just to ensure - return err + return err.Subject(r.Entry.Alias) } else { p.routes.Store(r.Entry.Alias, r) subtask.OnFinished("del from provider", func() { @@ -117,16 +119,14 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { } // Start implements task.TaskStarter. -func (p *Provider) Start(configSubtask task.Task) (res E.Error) { - errors := E.NewBuilder("errors starting routes") - defer errors.To(&res) - +func (p *Provider) Start(configSubtask task.Task) E.Error { // routes and event queue will stop on parent cancel providerTask := configSubtask - p.routes.RangeAllParallel(func(alias string, r *R.Route) { - errors.Add(p.startRoute(providerTask, r)) - }) + errs := p.routes.CollectErrorsParallel( + func(alias string, r *R.Route) error { + return p.startRoute(providerTask, r) + }) eventQueue := events.NewEventQueue( providerTask, @@ -139,11 +139,15 @@ func (p *Provider) Start(configSubtask task.Task) (res E.Error) { flushTask.Finish("events flushed") }, func(err E.Error) { - p.l.Error(err) + E.LogError("event error", err, p.Logger()) }, ) eventQueue.Start(p.watcher.Events(providerTask.Context())) - return + + if err := E.Join(errs...); err != nil { + return err.Subject(p.String()) + } + return nil } func (p *Provider) RangeRoutes(do func(string, *R.Route)) { @@ -156,14 +160,14 @@ func (p *Provider) GetRoute(alias string) (*R.Route, bool) { func (p *Provider) LoadRoutes() E.Error { var err E.Error - p.routes, err = p.LoadRoutesImpl() + p.routes, err = p.loadRoutesImpl() if p.routes.Size() > 0 { return err } if err == nil { return nil } - return E.FailWith("loading routes", err) + return err } func (p *Provider) NumRoutes() int { diff --git a/internal/route/raw.go b/internal/route/raw.go deleted file mode 100644 index f206a74..0000000 --- a/internal/route/raw.go +++ /dev/null @@ -1,94 +0,0 @@ -package route - -import ( - "errors" - "fmt" - "net" - "time" - - T "github.com/yusing/go-proxy/internal/proxy/fields" - U "github.com/yusing/go-proxy/internal/utils" -) - -type ( - RawStream struct { - *StreamRoute - - listener net.Listener - targetAddr net.Addr - } -) - -const ( - streamBufferSize = 8192 - streamDialTimeout = 5 * time.Second -) - -func NewRawStreamRoute(base *StreamRoute) *RawStream { - return &RawStream{ - StreamRoute: base, - } -} - -func (route *RawStream) Setup() error { - var lcfg net.ListenConfig - var err error - - switch route.Scheme.ListeningScheme { - case "tcp": - route.targetAddr, err = net.ResolveTCPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) - if err != nil { - return err - } - tcpListener, err := lcfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - route.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) - route.listener = tcpListener - case "udp": - route.targetAddr, err = net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) - if err != nil { - return err - } - udpListener, err := lcfg.ListenPacket(route.task.Context(), "udp", fmt.Sprintf(":%v", route.Port.ListeningPort)) - if err != nil { - return err - } - route.Port.ListeningPort = T.Port(udpListener.LocalAddr().(*net.UDPAddr).Port) - route.listener = newUDPListenerAdaptor(route.task.Context(), udpListener) - default: - return errors.New("invalid listening scheme: " + string(route.Scheme.ListeningScheme)) - } - - return nil -} - -func (route *RawStream) Accept() (net.Conn, error) { - if route.listener == nil { - return nil, errors.New("listener not yet set up") - } - return route.listener.Accept() -} - -func (route *RawStream) Handle(c net.Conn) error { - clientConn := c.(net.Conn) - - defer clientConn.Close() - route.task.OnCancel("close conn", func() { clientConn.Close() }) - - dialer := &net.Dialer{Timeout: streamDialTimeout} - - serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) - serverConn, err := dialer.DialContext(route.task.Context(), string(route.Scheme.ProxyScheme), serverAddr) - if err != nil { - return err - } - - pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) - return pipe.Start() -} - -func (route *RawStream) Close() error { - return route.listener.Close() -} diff --git a/internal/route/route.go b/internal/route/route.go index 93b3a4c..8cace61 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -89,5 +89,5 @@ func FromEntries(entries entry.RawEntries) (Routes, E.Error) { } }) - return routes, b.Build() + return routes, b.Error() } diff --git a/internal/route/stream.go b/internal/route/stream.go index 957a166..c94d75b 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" net "github.com/yusing/go-proxy/internal/net/types" @@ -18,13 +18,14 @@ import ( type StreamRoute struct { *entry.StreamEntry - net.Stream `json:"-"` + + stream net.Stream `json:"-"` HealthMon health.HealthMonitor `json:"health"` task task.Task - l logrus.FieldLogger + l zerolog.Logger } var ( @@ -39,11 +40,15 @@ func GetStreamProxies() F.Map[string, *StreamRoute] { func NewStreamRoute(entry *entry.StreamEntry) (impl, E.Error) { // TODO: support non-coherent scheme if !entry.Scheme.IsCoherent() { - return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) + return nil, E.Errorf("unsupported scheme: %v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme) } return &StreamRoute{ StreamEntry: entry, task: task.DummyTask(), + l: logger.With(). + Str("type", string(entry.Scheme.ListeningScheme)). + Str("name", entry.TargetName()). + Logger(), }, nil } @@ -62,57 +67,54 @@ func (r *StreamRoute) Start(providerSubtask task.Task) E.Error { defer streamRoutesMu.Unlock() if r.HealthCheck.Disable && (entry.UseLoadBalance(r) || entry.UseIdleWatcher(r)) { - logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled", r.Alias) - r.HealthCheck.Disable = true + r.l.Error().Msg("healthCheck.disabled cannot be false when loadbalancer or idlewatcher is enabled") + r.HealthCheck.Disable = false } - // if r.Scheme.ListeningScheme.IsTCP() { - // r.Stream = NewTCPRoute(r) - // } else { - // r.Stream = NewUDPRoute(r) - // } r.task = providerSubtask - r.Stream = NewRawStreamRoute(r) - r.l = logrus.WithField("route", r.Stream.String()) + r.stream = NewStream(r) switch { case entry.UseIdleWatcher(r): wakerTask := providerSubtask.Parent().Subtask("waker for " + string(r.Alias)) - waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.Stream) + waker, err := idlewatcher.NewStreamWaker(wakerTask, r.StreamEntry, r.stream) if err != nil { r.task.Finish(err) return err } - r.Stream = waker + r.stream = waker r.HealthMon = waker case entry.UseHealthCheck(r): r.HealthMon = health.NewRawHealthMonitor(r.TargetURL(), r.HealthCheck) } - if err := r.Setup(); err != nil { + if err := r.stream.Setup(); err != nil { r.task.Finish(err) - return E.FailWith("setup", err) + return E.From(err) } r.task.OnFinished("close stream", func() { - if err := r.Close(); err != nil { - r.l.Error("close stream error: ", err) + if err := r.stream.Close(); err != nil { + E.LogError("close stream failed", err, &r.l) } }) - r.task.OnFinished("remove from route table", func() { - streamRoutes.Delete(string(r.Alias)) - }) - r.l.Infof("listening on %s port %d", r.Scheme.ListeningScheme, r.Port.ListeningPort) + r.l.Info(). + Str("proto", string(r.Scheme.ListeningScheme)). + Int("port", int(r.Port.ListeningPort)). + Msg("listening") if r.HealthMon != nil { if err := r.HealthMon.Start(r.task.Subtask("health monitor")); err != nil { - logrus.Warn("health monitor error: ", err) + E.LogWarn("health monitor error", err, &r.l) } } go r.acceptConnections() streamRoutes.Store(string(r.Alias), r) + r.task.OnFinished("remove from route table", func() { + streamRoutes.Delete(string(r.Alias)) + }) return nil } @@ -128,25 +130,28 @@ func (r *StreamRoute) acceptConnections() { case <-r.task.Context().Done(): return default: - conn, err := r.Accept() + conn, err := r.stream.Accept() if err != nil { select { case <-r.task.Context().Done(): default: - r.l.Error("accept connection error: ", err) - r.task.Finish(err) + E.LogError("accept connection error", err, &r.l) } + r.task.Finish(err) return } - connTask := r.task.Subtask(fmt.Sprintf("connection from %s", conn.RemoteAddr())) + if conn == nil { + panic("connection is nil") + } + connTask := r.task.Subtask("connection") go func() { - err := r.Handle(conn) + err := r.stream.Handle(conn) if err != nil && !errors.Is(err, context.Canceled) { - r.l.Error(err) + E.LogError("handle connection error", err, &r.l) + connTask.Finish(err) } else { - connTask.Finish("connection closed") + connTask.Finish("closed") } - conn.Close() }() } } diff --git a/internal/route/stream_impl.go b/internal/route/stream_impl.go new file mode 100644 index 0000000..6b29dc9 --- /dev/null +++ b/internal/route/stream_impl.go @@ -0,0 +1,115 @@ +package route + +import ( + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/yusing/go-proxy/internal/net/types" + T "github.com/yusing/go-proxy/internal/proxy/fields" + U "github.com/yusing/go-proxy/internal/utils" +) + +type ( + Stream struct { + *StreamRoute + + listener types.StreamListener + targetAddr net.Addr + } +) + +const ( + streamFirstConnBufferSize = 128 + streamDialTimeout = 5 * time.Second +) + +func NewStream(base *StreamRoute) *Stream { + return &Stream{ + StreamRoute: base, + } +} + +func (stream *Stream) Addr() net.Addr { + if stream.listener == nil { + panic("listener is nil") + } + return stream.listener.Addr() +} + +func (stream *Stream) Setup() error { + var lcfg net.ListenConfig + var err error + + switch stream.Scheme.ListeningScheme { + case "tcp": + stream.targetAddr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%v", stream.Host, stream.Port.ProxyPort)) + if err != nil { + return err + } + tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", fmt.Sprintf(":%v", stream.Port.ListeningPort)) + if err != nil { + return err + } + stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) + stream.listener = types.NetListener(tcpListener) + case "udp": + stream.targetAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", stream.Host, stream.Port.ProxyPort)) + if err != nil { + return err + } + udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", fmt.Sprintf(":%v", stream.Port.ListeningPort)) + if err != nil { + return err + } + udpConn, ok := udpListener.(*net.UDPConn) + if !ok { + udpListener.Close() + return errors.New("udp listener is not *net.UDPConn") + } + stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port) + stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr) + default: + panic("should not reach here") + } + + return nil +} + +func (stream *Stream) Accept() (types.StreamConn, error) { + if stream.listener == nil { + return nil, errors.New("listener is nil") + } + return stream.listener.Accept() +} + +func (stream *Stream) Handle(conn types.StreamConn) error { + switch conn := conn.(type) { + case *UDPConn: + switch stream := stream.listener.(type) { + case *UDPForwarder: + return stream.Handle(conn) + default: + return fmt.Errorf("unexpected listener type: %T", stream) + } + case io.ReadWriteCloser: + stream.task.OnCancel("close conn", func() { conn.Close() }) + + dialer := &net.Dialer{Timeout: streamDialTimeout} + dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) + if err != nil { + return err + } + defer dstConn.Close() + pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn) + return pipe.Start() + default: + return fmt.Errorf("unexpected conn type: %T", conn) + } +} + +func (stream *Stream) Close() error { + return stream.listener.Close() +} diff --git a/internal/route/tcp.go b/internal/route/tcp.go deleted file mode 100755 index d14b482..0000000 --- a/internal/route/tcp.go +++ /dev/null @@ -1,68 +0,0 @@ -package route - -// import ( -// "context" -// "fmt" -// "net" -// "time" - -// "github.com/yusing/go-proxy/internal/net/types" -// T "github.com/yusing/go-proxy/internal/proxy/fields" -// U "github.com/yusing/go-proxy/internal/utils" -// F "github.com/yusing/go-proxy/internal/utils/functional" -// ) - -// const tcpDialTimeout = 5 * time.Second - -// type ( -// TCPConnMap = F.Map[net.Conn, struct{}] -// TCPRoute struct { -// *StreamRoute -// listener *net.TCPListener -// } -// ) - -// func NewTCPRoute(base *StreamRoute) *TCPRoute { -// return &TCPRoute{StreamRoute: base} -// } - -// func (route *TCPRoute) Setup() error { -// var cfg net.ListenConfig -// in, err := cfg.Listen(route.task.Context(), "tcp", fmt.Sprintf(":%v", route.Port.ListeningPort)) -// if err != nil { -// return err -// } -// //! this read the allocated port from original ':0' -// route.Port.ListeningPort = T.Port(in.Addr().(*net.TCPAddr).Port) -// route.listener = in.(*net.TCPListener) -// return nil -// } - -// func (route *TCPRoute) Accept() (types.StreamConn, error) { -// return route.listener.Accept() -// } - -// func (route *TCPRoute) Handle(c types.StreamConn) error { -// clientConn := c.(net.Conn) - -// defer clientConn.Close() -// route.task.OnCancel("close conn", func() { clientConn.Close() }) - -// ctx, cancel := context.WithTimeout(route.task.Context(), tcpDialTimeout) - -// serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) -// dialer := &net.Dialer{} - -// serverConn, err := dialer.DialContext(ctx, string(route.Scheme.ProxyScheme), serverAddr) -// cancel() -// if err != nil { -// return err -// } - -// pipe := U.NewBidirectionalPipe(route.task.Context(), clientConn, serverConn) -// return pipe.Start() -// } - -// func (route *TCPRoute) Close() error { -// return route.listener.Close() -// } diff --git a/internal/route/udp.go b/internal/route/udp.go deleted file mode 100755 index 8f7ae89..0000000 --- a/internal/route/udp.go +++ /dev/null @@ -1,149 +0,0 @@ -package route - -// import ( -// "errors" -// "fmt" -// "io" -// "net" - -// "github.com/yusing/go-proxy/internal/net/types" -// T "github.com/yusing/go-proxy/internal/proxy/fields" -// U "github.com/yusing/go-proxy/internal/utils" -// F "github.com/yusing/go-proxy/internal/utils/functional" -// ) - -// type ( -// UDPRoute struct { -// *StreamRoute - -// connMap UDPConnMap - -// listeningConn net.PacketConn -// targetAddr *net.UDPAddr -// } -// UDPConn struct { -// key string -// src net.Conn -// dst net.Conn -// U.BidirectionalPipe -// } -// UDPConnMap = F.Map[string, *UDPConn] -// ) - -// var NewUDPConnMap = F.NewMap[UDPConnMap] - -// const udpBufferSize = 8192 - -// func NewUDPRoute(base *StreamRoute) *UDPRoute { -// return &UDPRoute{ -// StreamRoute: base, -// connMap: NewUDPConnMap(), -// } -// } - -// func (route *UDPRoute) Setup() error { -// var cfg net.ListenConfig -// source, err := cfg.ListenPacket(route.task.Context(), string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort)) -// if err != nil { -// return err -// } -// raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort)) -// if err != nil { -// source.Close() -// return err -// } - -// //! this read the allocated listeningPort from original ':0' -// route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port) - -// route.listeningConn = source -// route.targetAddr = raddr - -// return nil -// } - -// func (route *UDPRoute) Accept() (types.StreamConn, error) { -// in := route.listeningConn - -// buffer := make([]byte, udpBufferSize) -// nRead, srcAddr, err := in.ReadFrom(buffer) -// if err != nil { -// return nil, err -// } - -// if nRead == 0 { -// return nil, io.ErrShortBuffer -// } - -// key := srcAddr.String() -// conn, ok := route.connMap.Load(key) - -// if !ok { -// srcConn, err := net.Dial(srcAddr.Network(), srcAddr.String()) -// if err != nil { -// return nil, err -// } -// dstConn, err := net.Dial(route.targetAddr.Network(), route.targetAddr.String()) -// if err != nil { -// srcConn.Close() -// return nil, err -// } -// conn = &UDPConn{ -// key, -// srcConn, -// dstConn, -// U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), -// } -// route.connMap.Store(key, conn) -// } - -// _, err = conn.dst.Write(buffer[:nRead]) -// return conn, err -// } - -// func (route *UDPRoute) Handle(c types.StreamConn) error { -// switch c := c.(type) { -// case *UDPConn: -// err := c.Start() -// route.connMap.Delete(c.key) -// c.Close() -// return err -// case *net.TCPConn: -// in := route.listeningConn -// srcConn, err := net.DialTCP("tcp", nil, c.RemoteAddr().(*net.TCPAddr)) -// if err != nil { -// return err -// } -// err = U.NewBidirectionalPipe(route.task.Context(), sourceRWCloser{in, c}, sourceRWCloser{in, srcConn}).Start() -// c.Close() -// return err -// } -// return fmt.Errorf("unknown conn type: %T", c) -// } - -// func (route *UDPRoute) Close() error { -// route.connMap.RangeAllParallel(func(k string, v *UDPConn) { -// v.Close() -// }) -// route.connMap.Clear() -// return route.listeningConn.Close() -// } - -// // Close implements types.StreamConn -// func (conn *UDPConn) Close() error { -// return errors.Join(conn.src.Close(), conn.dst.Close()) -// } - -// // RemoteAddr implements types.StreamConn -// func (conn *UDPConn) RemoteAddr() net.Addr { -// return conn.src.RemoteAddr() -// } - -// type sourceRWCloser struct { -// server net.PacketConn -// net.Conn -// } - -// func (w sourceRWCloser) Write(p []byte) (int, error) { -// return w.server.WriteTo(p, w.RemoteAddr().(*net.UDPAddr)) -// } diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go new file mode 100644 index 0000000..8ed4532 --- /dev/null +++ b/internal/route/udp_forwarder.go @@ -0,0 +1,197 @@ +package route + +import ( + "context" + "fmt" + "net" + "sync" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/types" + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +type ( + UDPForwarder struct { + ctx context.Context + forwarder *net.UDPConn + dstAddr net.Addr + connMap F.Map[string, *UDPConn] + mu sync.Mutex + } + UDPConn struct { + srcAddr *net.UDPAddr + conn net.Conn + buf *UDPBuf + } + UDPBuf struct { + data, oob []byte + n, oobn int + } +) + +const udpConnBufferSize = 4096 + +func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder { + return &UDPForwarder{ + ctx: ctx, + forwarder: forwarder, + dstAddr: dstAddr, + connMap: F.NewMapOf[string, *UDPConn](), + } +} + +func newUDPBuf() *UDPBuf { + return &UDPBuf{ + data: make([]byte, udpConnBufferSize), + oob: make([]byte, udpConnBufferSize), + } +} + +func (conn *UDPConn) SrcAddrString() string { + return conn.srcAddr.Network() + "://" + conn.srcAddr.String() +} + +func (conn *UDPConn) DstAddrString() string { + return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() +} + +func (w *UDPForwarder) Addr() net.Addr { + return w.forwarder.LocalAddr() +} + +func (w *UDPForwarder) Accept() (types.StreamConn, error) { + buf := newUDPBuf() + addr, err := w.readFromListener(buf) + if err != nil { + return nil, err + } + return &UDPConn{ + srcAddr: addr, + buf: buf, + }, nil +} + +func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) { + switch dstAddr := w.dstAddr.(type) { + case *net.UDPAddr: + var laddr *net.UDPAddr + if dstAddr.IP.IsLoopback() { + laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:") + } + dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr) + case *net.TCPAddr: + dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr) + default: + err = fmt.Errorf("unsupported network %s", w.dstAddr.Network()) + } + return +} + +func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) { + buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob) + if err == nil { + logger.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) + } + return +} + +func (dst *UDPConn) read() (err error) { + switch dstConn := dst.conn.(type) { + case *net.UDPConn: + dst.buf.n, dst.buf.oobn, _, _, err = dstConn.ReadMsgUDP(dst.buf.data, dst.buf.oob) + default: + dst.buf.n, err = dstConn.Read(dst.buf.data[:dst.buf.n]) + dst.buf.oobn = 0 + } + if err == nil { + logger.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", dst.DstAddrString(), dst.buf.n, dst.buf.oobn) + } + return +} + +func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) { + buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr) + if err == nil { + logger.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) + } + return +} + +func (dst *UDPConn) write() (err error) { + switch dstConn := dst.conn.(type) { + case *net.UDPConn: + dst.buf.n, dst.buf.oobn, err = dstConn.WriteMsgUDP(dst.buf.data[:dst.buf.n], dst.buf.oob[:dst.buf.oobn], nil) + if err == nil { + logger.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", dst.DstAddrString(), dst.buf.n, dst.buf.oobn) + } + default: + _, err = dstConn.Write(dst.buf.data[:dst.buf.n]) + if err == nil { + logger.Debug().Msgf("write to dst %s success (n: %d)", dst.DstAddrString(), dst.buf.n) + } + } + + return nil +} + +func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { + conn, ok := streamConn.(*UDPConn) + if !ok { + panic("unexpected conn type") + } + key := conn.srcAddr.String() + + w.mu.Lock() + dst, ok := w.connMap.Load(key) + if !ok { + var err error + dst = conn + dst.conn, err = w.dialDst() + if err != nil { + return err + } + if err := dst.write(); err != nil { + dst.conn.Close() + return err + } + w.connMap.Store(key, dst) + } else { + conn.conn = dst.conn + if err := conn.write(); err != nil { + w.connMap.Delete(key) + dst.conn.Close() + return err + } + } + w.mu.Unlock() + + for { + select { + case <-w.ctx.Done(): + return nil + default: + if err := dst.read(); err != nil { + w.connMap.Delete(key) + dst.conn.Close() + return err + } + + if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil { + return err + } + } + } +} + +func (w *UDPForwarder) Close() error { + errs := E.NewBuilder("errors closing udp conn") + w.mu.Lock() + defer w.mu.Unlock() + w.connMap.RangeAll(func(key string, conn *UDPConn) { + errs.Add(conn.conn.Close()) + }) + w.connMap.Clear() + errs.Add(w.forwarder.Close()) + return errs.Error() +} diff --git a/internal/route/udp_listener.go b/internal/route/udp_listener.go deleted file mode 100644 index 04f3b65..0000000 --- a/internal/route/udp_listener.go +++ /dev/null @@ -1,73 +0,0 @@ -package route - -import ( - "context" - "io" - "net" - "sync" - - F "github.com/yusing/go-proxy/internal/utils/functional" -) - -type ( - UDPListener struct { - ctx context.Context - listener net.PacketConn - connMap UDPConnMap - mu sync.Mutex - } - UDPConnMap = F.Map[string, net.Conn] -) - -var NewUDPConnMap = F.NewMap[UDPConnMap] - -func newUDPListenerAdaptor(ctx context.Context, listener net.PacketConn) net.Listener { - return &UDPListener{ - ctx: ctx, - listener: listener, - connMap: NewUDPConnMap(), - } -} - -// Addr implements net.Listener. -func (route *UDPListener) Addr() net.Addr { - return route.listener.LocalAddr() -} - -func (udpl *UDPListener) Accept() (net.Conn, error) { - in := udpl.listener - - buffer := make([]byte, streamBufferSize) - nRead, srcAddr, err := in.ReadFrom(buffer) - if err != nil { - return nil, err - } - - if nRead == 0 { - return nil, io.ErrShortBuffer - } - - udpl.mu.Lock() - defer udpl.mu.Unlock() - - key := srcAddr.String() - conn, ok := udpl.connMap.Load(key) - if !ok { - dialer := &net.Dialer{Timeout: streamDialTimeout} - srcConn, err := dialer.DialContext(udpl.ctx, srcAddr.Network(), srcAddr.String()) - if err != nil { - return nil, err - } - udpl.connMap.Store(key, srcConn) - } - return conn, nil -} - -// Close implements net.Listener. -func (route *UDPListener) Close() error { - route.connMap.RangeAllParallel(func(key string, conn net.Conn) { - conn.Close() - }) - route.connMap.Clear() - return route.listener.Close() -} diff --git a/internal/server/server.go b/internal/server/server.go index 0d917ec..f6f4bf5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "errors" + "io" "log" + "net/http" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/autocert" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) @@ -23,6 +26,8 @@ type Server struct { startTime time.Time task task.Task + + l zerolog.Logger } type Options struct { @@ -34,22 +39,11 @@ type Options struct { Handler http.Handler } -type LogrusWrapper struct { - *logrus.Entry -} - -func (l LogrusWrapper) Write(b []byte) (int, error) { - return l.Logger.WriterLevel(logrus.ErrorLevel).Write(b) -} - func NewServer(opt Options) (s *Server) { var httpSer, httpsSer *http.Server var httpHandler http.Handler - logger := log.Default() - logger.SetOutput(LogrusWrapper{ - logrus.WithFields(logrus.Fields{"?": "server", "name": opt.Name}), - }) + logger := logging.With().Str("module", "server").Str("name", opt.Name).Logger() certAvailable := false if opt.CertProvider != nil { @@ -67,14 +61,14 @@ func NewServer(opt Options) (s *Server) { httpSer = &http.Server{ Addr: opt.HTTPAddr, Handler: httpHandler, - ErrorLog: logger, + ErrorLog: log.New(io.Discard, "", 0), // most are tls related } } if certAvailable && opt.HTTPSAddr != "" { httpsSer = &http.Server{ Addr: opt.HTTPSAddr, Handler: opt.Handler, - ErrorLog: logger, + ErrorLog: log.New(io.Discard, "", 0), // most are tls related TLSConfig: &tls.Config{ GetCertificate: opt.CertProvider.GetCert, }, @@ -86,6 +80,7 @@ func NewServer(opt Options) (s *Server) { http: httpSer, https: httpsSer, task: task.GlobalTask(opt.Name + " server"), + l: logger, } } @@ -101,19 +96,19 @@ func (s *Server) Start() { s.startTime = time.Now() if s.http != nil { - s.httpStarted = true - logrus.Printf("starting http %s server on %s", s.Name, s.http.Addr) go func() { s.handleErr("http", s.http.ListenAndServe()) }() + s.httpStarted = true + s.l.Info().Str("addr", s.http.Addr).Msg("server started") } if s.https != nil { - s.httpsStarted = true - logrus.Printf("starting https %s server on %s", s.Name, s.https.Addr) go func() { s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath())) }() + s.httpsStarted = true + s.l.Info().Str("addr", s.https.Addr).Msgf("server started") } s.task.OnFinished("stop server", s.stop) @@ -144,7 +139,7 @@ func (s *Server) handleErr(scheme string, err error) { case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): return default: - logrus.Fatalf("%s server %s error: %s", scheme, s.Name, err) + s.l.Fatal().Err(err).Str("scheme", scheme).Msg("server error") } } @@ -162,5 +157,3 @@ func redirectToTLSHandler(port string) http.HandlerFunc { http.Redirect(w, r, r.URL.String(), redirectCode) } } - -var logger = logrus.WithField("module", "server") diff --git a/internal/task/task.go b/internal/task/task.go index 1ff7400..d848136 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -4,14 +4,14 @@ import ( "context" "errors" "fmt" - "runtime" "strings" "sync" "time" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -100,7 +100,7 @@ type ( subtasks F.Set[*task] subTasksWg sync.WaitGroup - name, line string + name string OnFinishedFuncs []func() OnFinishedMu sync.Mutex @@ -113,6 +113,8 @@ type ( var ( ErrProgramExiting = errors.New("program exiting") ErrTaskCanceled = errors.New("task canceled") + + logger = logging.With().Str("module", "task").Logger() ) // GlobalTask returns a new Task with the given name, derived from the global context. @@ -159,14 +161,22 @@ func GlobalContextWait(timeout time.Duration) { case <-done: return case <-after: - logrus.Warn("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) + logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree()) return } } } +func (t *task) trace() *zerolog.Event { + return logger.Trace().Str("name", t.name) +} + func (t *task) Name() string { - return t.name + if !common.IsTrace { + return t.name + } + parts := strings.Split(t.name, " > ") + return parts[len(parts)-1] } func (t *task) String() string { @@ -212,20 +222,18 @@ func (t *task) OnFinished(about string, fn func()) { onCompTask := GlobalTask(t.name + " > OnFinished > " + about) go t.runAllOnFinished(onCompTask) } - var file string - var line int - if common.IsTrace { - _, file, line, _ = runtime.Caller(1) - } idx := len(t.OnFinishedFuncs) wrapped := func() { defer func() { if err := recover(); err != nil { - logrus.Errorf("panic in %s > OnFinished[%d]: %q\nline %s:%d\n%v", t.name, idx, about, file, line, err) + logger.Error(). + Str("name", t.name). + Interface("err", err). + Msg("panic in " + about) } }() fn() - logrus.Tracef("line %s:%d\n%s > OnFinished[%d] done: %s", file, line, t.name, idx, about) + logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about) } t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped) } @@ -236,7 +244,7 @@ func (t *task) OnCancel(about string, fn func()) { <-t.ctx.Done() fn() onCompTask.Finish("done") - logrus.Tracef("%s > onCancel done: %s", t.name, about) + t.trace().Msg("onCancel done: " + about) }() } @@ -276,14 +284,10 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n parent.subTasksWg.Add(1) parent.subtasks.Add(subtask) if common.IsTrace { - _, file, line, ok := runtime.Caller(3) - if ok { - subtask.line = fmt.Sprintf("%s:%d", file, line) - } - logrus.Tracef("line %s\n%s started", subtask.line, name) + subtask.trace().Msg("started") go func() { subtask.Wait() - logrus.Tracef("%s finished: %s", subtask.Name(), subtask.FinishCause()) + subtask.trace().Msg("finished: " + subtask.FinishCause().Error()) }() } go func() { @@ -324,12 +328,6 @@ func (t *task) tree(prefix ...string) string { pre = prefix[0] sb.WriteString(pre + "- ") } - if t.line != "" { - sb.WriteString("line " + t.line + "\n") - if len(pre) > 0 { - sb.WriteString(pre + "- ") - } - } sb.WriteString(t.Name() + "\n") t.subtasks.RangeAll(func(subtask *task) { sb.WriteString(subtask.tree(pre + " ")) @@ -341,7 +339,6 @@ func (t *task) tree(prefix ...string) string { // // The map contains the following keys: // - name: the name of the task -// - line: the line number of the task, if available // - subtasks: a slice of maps, each representing a subtask // // The subtask maps contain the same keys, recursively. @@ -354,11 +351,8 @@ func (t *task) tree(prefix ...string) string { // only. func (t *task) serialize() map[string]any { m := make(map[string]any) - parts := strings.Split(t.name, ">") - m["name"] = strings.TrimSpace(parts[len(parts)-1]) - if t.line != "" { - m["line"] = t.line - } + parts := strings.Split(t.name, " > ") + m["name"] = parts[len(parts)-1] if t.subtasks.Size() > 0 { m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size()) t.subtasks.RangeAll(func(subtask *task) { diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index 56669d1..8d95b17 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -1,10 +1,10 @@ package functional import ( + "errors" "sync" "github.com/puzpuzpuz/xsync/v3" - E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) @@ -12,6 +12,8 @@ type Map[KT comparable, VT any] struct { *xsync.MapOf[KT, VT] } +const minParallelSize = 4 + func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] { return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)} } @@ -113,6 +115,11 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { // // nothing func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { + if m.Size() < minParallelSize { + m.RangeAll(do) + return + } + var wg sync.WaitGroup m.Range(func(k KT, v VT) bool { @@ -126,6 +133,45 @@ func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { wg.Wait() } +// CollectErrors calls the given function for each key-value pair in the map, +// then returns a slice of errors collected. +func (m Map[KT, VT]) CollectErrors(do func(k KT, v VT) error) []error { + errs := make([]error, 0) + m.Range(func(k KT, v VT) bool { + if err := do(k, v); err != nil { + errs = append(errs, err) + } + return true + }) + return errs +} + +// CollectErrors calls the given function for each key-value pair in the map, +// then returns a slice of errors collected. +func (m Map[KT, VT]) CollectErrorsParallel(do func(k KT, v VT) error) []error { + if m.Size() < minParallelSize { + return m.CollectErrors(do) + } + + errs := make([]error, 0) + mu := sync.Mutex{} + wg := sync.WaitGroup{} + m.Range(func(k KT, v VT) bool { + wg.Add(1) + go func() { + if err := do(k, v); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + wg.Done() + }() + return true + }) + wg.Wait() + return errs +} + // RemoveAll removes all key-value pairs from the map where the value matches the given criteria. // // Parameters: @@ -160,12 +206,12 @@ func (m Map[KT, VT]) Has(k KT) bool { // Returns: // // error: if the unmarshaling fails -func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.Error { +func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) error { if m.Size() != 0 { - return E.FailedWhy("unmarshal from yaml", "map is not empty") + return errors.New("cannot unmarshal into non-empty map") } tmp := make(map[KT]VT) - if err := E.From(yaml.Unmarshal(data, tmp)); err != nil { + if err := yaml.Unmarshal(data, tmp); err != nil { return err } for k, v := range tmp { diff --git a/internal/utils/io.go b/internal/utils/io.go index ea79772..c60ff33 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -83,20 +83,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re } } -func (p BidirectionalPipe) Start() error { +func (p BidirectionalPipe) Start() E.Error { var wg sync.WaitGroup wg.Add(2) b := E.NewBuilder("bidirectional pipe error") go func() { - b.AddE(p.pSrcDst.Start()) + b.Add(p.pSrcDst.Start()) wg.Done() }() go func() { - b.AddE(p.pDstSrc.Start()) + b.Add(p.pDstSrc.Start()) wg.Done() }() wg.Wait() - return b.Build().Error() + return b.Error() } // Copyright 2009 The Go Authors. All rights reserved. @@ -152,18 +152,18 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) } -func LoadJSON[T any](path string, pointer *T) E.Error { - data, err := E.Check(os.ReadFile(path)) - if err.HasError() { +func LoadJSON[T any](path string, pointer *T) error { + data, err := os.ReadFile(path) + if err != nil { return err } - return E.From(json.Unmarshal(data, pointer)) + return json.Unmarshal(data, pointer) } -func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.Error { - data, err := E.Check(json.Marshal(pointer)) - if err.HasError() { +func SaveJSON[T any](path string, pointer *T, perm os.FileMode) error { + data, err := json.Marshal(pointer) + if err != nil { return err } - return E.From(os.WriteFile(path, data, perm)) + return os.WriteFile(path, data, perm) } diff --git a/internal/utils/nearest_field.go b/internal/utils/nearest_field.go new file mode 100644 index 0000000..01cae8b --- /dev/null +++ b/internal/utils/nearest_field.go @@ -0,0 +1,49 @@ +package utils + +import ( + "reflect" + + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +func NearestField(input string, s any) string { + minDistance := -1 + nearestField := "" + var fields []string + switch s := s.(type) { + case []string: + fields = s + default: + t := reflect.TypeOf(s) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + fields = make([]string, 0) + for i := 0; i < t.NumField(); i++ { + jsonTag, ok := t.Field(i).Tag.Lookup("json") + if ok { + fields = append(fields, jsonTag) + } else { + fields = append(fields, t.Field(i).Name) + } + } + } else if t.Kind() == reflect.Map { + keys := reflect.ValueOf(s).MapKeys() + fields = make([]string, len(keys)) + for i, key := range keys { + fields[i] = key.String() + } + } else { + panic("unsupported type: " + t.String()) + } + } + for _, field := range fields { + distance := strutils.LevenshteinDistance(input, field) + if minDistance == -1 || distance < minDistance { + minDistance = distance + nearestField = field + } + } + return nearestField +} diff --git a/internal/utils/schema.go b/internal/utils/schema.go index dd66849..0ca0099 100644 --- a/internal/utils/schema.go +++ b/internal/utils/schema.go @@ -1,15 +1,23 @@ package utils import ( + "sync" + "github.com/santhosh-tekuri/jsonschema" ) var ( schemaCompiler = jsonschema.NewCompiler() schemaStorage = make(map[string]*jsonschema.Schema) + schemaMu sync.Mutex ) func GetSchema(path string) *jsonschema.Schema { + if schema, ok := schemaStorage[path]; ok { + return schema + } + schemaMu.Lock() + defer schemaMu.Unlock() if schema, ok := schemaStorage[path]; ok { return schema } diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 7565a5e..7ab9ce4 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" "reflect" "strconv" "strings" @@ -13,6 +12,7 @@ import ( "github.com/santhosh-tekuri/jsonschema" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" "gopkg.in/yaml.v3" ) @@ -23,17 +23,27 @@ type ( } ) +var ( + ErrInvalidType = E.New("invalid type") + ErrNilValue = E.New("nil") + ErrUnsettable = E.New("unsettable") + ErrUnsupportedConvertion = E.New("unsupported convertion") + ErrMapMissingColon = E.New("map missing colon") + ErrMapTooManyColons = E.New("map too many colons") + ErrUnknownField = E.New("unknown field") +) + func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { var i any err := yaml.Unmarshal(data, &i) if err != nil { - return E.FailWith("unmarshal yaml", err) + return E.From(err) } m, err := json.Marshal(i) if err != nil { - return E.FailWith("marshal json", err) + return E.From(err) } err = schema.Validate(bytes.NewReader(m)) @@ -43,14 +53,14 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { var valErr *jsonschema.ValidationError if !errors.As(err, &valErr) { - return E.UnexpectedError(err) + panic(err) } b := E.NewBuilder("yaml validation error") for _, e := range valErr.Causes { - b.Addf(e.Message) + b.Adds(e.Message) } - return b.Build() + return b.Error() } // Serialize converts the given data into a map[string]any representation. @@ -66,7 +76,7 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.Error { // Returns: // - result: The resulting map[string]any representation of the data. // - error: An error if the data type is unsupported or if there is an error during conversion. -func Serialize(data any) (SerializedObject, E.Error) { +func Serialize(data any) (SerializedObject, error) { result := make(map[string]any) // Use reflection to inspect the data type @@ -74,7 +84,7 @@ func Serialize(data any) (SerializedObject, E.Error) { // Check if the value is valid if !value.IsValid() { - return nil, E.Invalid("data", fmt.Sprintf("type: %T", data)) + return nil, ErrInvalidType.Subjectf("%T", data) } // Dereference pointers if necessary @@ -123,7 +133,7 @@ func Serialize(data any) (SerializedObject, E.Error) { } } default: - return nil, E.Unsupported("type", value.Kind()) + return nil, errors.New("serialize: unsupported data type " + value.Kind().String()) } return result, nil @@ -139,11 +149,10 @@ func Serialize(data any) (SerializedObject, E.Error) { // The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. func Deserialize(src SerializedObject, dst any) E.Error { if src == nil { - return E.Invalid("src", "nil") + return E.Errorf("deserialize: src is %w", ErrNilValue) } - if dst == nil { - return E.Invalid("nil dst", fmt.Sprintf("type: %T", dst)) + return E.Errorf("deserialize: dst is %w", ErrNilValue) } dstV := reflect.ValueOf(dst) @@ -151,7 +160,7 @@ func Deserialize(src SerializedObject, dst any) E.Error { if dstV.Kind() == reflect.Ptr { if dstV.IsNil() { - return E.Invalid("nil dst", fmt.Sprintf("type: %T", dst)) + return E.Errorf("deserialize: dst is %w", ErrNilValue) } dstV = dstV.Elem() dstT = dstV.Type() @@ -161,7 +170,7 @@ func Deserialize(src SerializedObject, dst any) E.Error { // convert target fields to lower no-snake // then check if the field of data is in the target - // TODO: use E.Builder to collect errors from all fields + errs := E.NewBuilder("deserialize error") switch dstV.Kind() { case reflect.Struct: @@ -173,12 +182,13 @@ func Deserialize(src SerializedObject, dst any) E.Error { if field, ok := mapping[ToLowerNoSnake(k)]; ok { err := Convert(reflect.ValueOf(v), field) if err != nil { - return err.Subject(k) + errs.Add(err.Subject(k)) } } else { - return E.Unexpected("field", k).Subjectf("%T", dst) + errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, dst)))) } } + return errs.Error() case reflect.Map: if dstV.IsNil() { dstV.Set(reflect.MakeMap(dstT)) @@ -187,15 +197,14 @@ func Deserialize(src SerializedObject, dst any) E.Error { tmp := reflect.New(dstT.Elem()).Elem() err := Convert(reflect.ValueOf(src[k]), tmp) if err != nil { - return err.Subject(k) + errs.Add(err.Subject(k)) } dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp) } + return errs.Error() default: - return E.Unsupported("target type", fmt.Sprintf("%T", dst)) + return ErrUnsupportedConvertion.Subject("deserialize to " + dstT.String()) } - - return nil } // Convert attempts to convert the src to dst. @@ -220,7 +229,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { } if !dst.CanSet() { - return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface())) + return ErrUnsettable.Subject(dstT.String()) } if dst.Kind() == reflect.Pointer { @@ -241,12 +250,12 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { case srcT.Kind() == reflect.Map: obj, ok := src.Interface().(SerializedObject) if !ok { - return E.TypeMismatch[SerializedObject](src.Interface()) + return ErrUnsupportedConvertion.Subject(dstT.String() + " to " + srcT.String()) } return Deserialize(obj, dst.Addr().Interface()) case srcT.Kind() == reflect.Slice: if dstT.Kind() != reflect.Slice { - return E.TypeError("slice", srcT, dstT) + return ErrUnsupportedConvertion.Subject(dstT.String() + " to slice") } newSlice := reflect.MakeSlice(dstT, 0, src.Len()) i := 0 @@ -271,7 +280,7 @@ func Convert(src reflect.Value, dst reflect.Value) E.Error { var ok bool // check if (*T).Convertor is implemented if converter, ok = dst.Addr().Interface().(Converter); !ok { - return E.TypeError("conversion", srcT, dstT) + return ErrUnsupportedConvertion.Subjectf("%s to %s", srcT, dstT) } return converter.ConvertFrom(src.Interface()) @@ -297,8 +306,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E } d, err := time.ParseDuration(src) if err != nil { - convErr = E.Invalid("duration", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(d)) return @@ -308,24 +316,21 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E case reflect.Bool: b, err := strconv.ParseBool(src) if err != nil { - convErr = E.Invalid("boolean", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(b)) return case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, err := strconv.ParseInt(src, 10, 64) if err != nil { - convErr = E.Invalid("int", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(i).Convert(dst.Type())) return case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i, err := strconv.ParseUint(src, 10, 64) if err != nil { - convErr = E.Invalid("uint", src) - return + return true, E.From(err) } dst.Set(reflect.ValueOf(i).Convert(dst.Type())) return @@ -340,7 +345,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E case reflect.Slice: // one liner is comma separated list if len(lines) == 0 { - dst.Set(reflect.ValueOf(CommaSeperatedList(src))) + dst.Set(reflect.ValueOf(strutils.CommaSeperatedList(src))) return } sl := make([]string, 0, len(lines)) @@ -356,35 +361,36 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.E tmp = sl case reflect.Map: m := make(map[string]string, len(lines)) + errs := E.NewBuilder("invalid map") for i, line := range lines { parts := strings.Split(line, ":") if len(parts) < 2 { - convErr = E.Invalid("map", "missing colon").Subjectf("line#%d", i+1).With(line) - return + errs.Add(ErrMapMissingColon.Subjectf("line %d", i+1)) } if len(parts) > 2 { - convErr = E.Invalid("map", "too many colons").Subjectf("line#%d", i+1).With(line) - return + errs.Add(ErrMapTooManyColons.Subjectf("line %d", i+1)) } k := strings.TrimSpace(parts[0]) v := strings.TrimSpace(parts[1]) m[k] = v } + if errs.HasError() { + return true, errs.Error() + } tmp = m } if tmp == nil { - convertible = false - return + return false, nil } return true, Convert(reflect.ValueOf(tmp), dst) } -func DeserializeJSON(j map[string]string, target any) E.Error { - data, err := E.Check(json.Marshal(j)) +func DeserializeJSON(j map[string]string, target any) error { + data, err := json.Marshal(j) if err != nil { return err } - return E.From(json.Unmarshal(data, target)) + return json.Unmarshal(data, target) } func ToLowerNoSnake(s string) string { diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 632d852..29433ee 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -1,6 +1,7 @@ package utils import ( + "errors" "reflect" "testing" @@ -37,14 +38,14 @@ var testStructSerialized = map[string]any{ func TestSerialize(t *testing.T) { s, err := Serialize(testStruct) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, s, testStructSerialized) } func TestDeserialize(t *testing.T) { var s S err := Deserialize(testStructSerialized, &s) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectDeepEqual(t, s, testStruct) } @@ -65,42 +66,42 @@ func TestStringIntConvert(t *testing.T) { ok, err := ConvertString(s, reflect.ValueOf(&test.i8)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i8, int8(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i16)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i16, int16(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i32)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i32, int32(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.i64)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.i64, int64(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u8)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u8, uint8(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u16)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u16, uint16(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u32)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u32, uint32(127)) ok, err = ConvertString(s, reflect.ValueOf(&test.u64)) ExpectTrue(t, ok) - ExpectNoError(t, err.Error()) + ExpectNoError(t, err) ExpectEqual(t, test.u64, uint64(127)) } @@ -113,6 +114,8 @@ type testType struct { bar string } +var errInvalid = errors.New("invalid input type") + func (c *testType) ConvertFrom(v any) E.Error { switch v := v.(type) { case string: @@ -122,14 +125,14 @@ func (c *testType) ConvertFrom(v any) E.Error { c.foo = v return nil default: - return E.Invalid("input type", v) + return E.Errorf("%w %T", errInvalid, v) } } func TestConvertor(t *testing.T) { t.Run("string", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m).Error()) + ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m)) ExpectEqual(t, m.Test.foo, 0) ExpectEqual(t, m.Test.bar, "bar") @@ -137,7 +140,7 @@ func TestConvertor(t *testing.T) { t.Run("int", func(t *testing.T) { m := new(testModel) - ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m).Error()) + ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m)) ExpectEqual(t, m.Test.foo, 123) ExpectEqual(t, m.Test.bar, "") @@ -145,6 +148,6 @@ func TestConvertor(t *testing.T) { t.Run("invalid", func(t *testing.T) { m := new(testModel) - ExpectError(t, E.ErrInvalid, Deserialize(map[string]any{"Test": 123.456}, m).Error()) + ExpectError(t, errInvalid, Deserialize(map[string]any{"Test": 123.456}, m)) }) } diff --git a/internal/utils/string.go b/internal/utils/string.go deleted file mode 100644 index 9cbabc3..0000000 --- a/internal/utils/string.go +++ /dev/null @@ -1,35 +0,0 @@ -package utils - -import ( - "net/url" - "strconv" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" -) - -func CommaSeperatedList(s string) []string { - res := strings.Split(s, ",") - for i, part := range res { - res[i] = strings.TrimSpace(part) - } - return res -} - -func Title(s string) string { - // TODO: support other languages. - return cases.Title(language.AmericanEnglish).String(s) -} - -func ExtractPort(fullURL string) (int, error) { - url, err := url.Parse(fullURL) - if err != nil { - return 0, err - } - return strconv.Atoi(url.Port()) -} - -func PortString(port uint16) string { - return strconv.FormatUint(uint64(port), 10) -} diff --git a/internal/utils/strutils/ansi/ansi.go b/internal/utils/strutils/ansi/ansi.go new file mode 100644 index 0000000..da89ba9 --- /dev/null +++ b/internal/utils/strutils/ansi/ansi.go @@ -0,0 +1,25 @@ +package ansi + +import "regexp" + +var ansiRegexp = regexp.MustCompile(`\x1b\[[0-9;]*m`) + +const ( + BrightRed = "\x1b[91m" + BrightGreen = "\x1b[92m" + BrightYellow = "\x1b[93m" + BrightCyan = "\x1b[96m" + BrightWhite = "\x1b[97m" + Bold = "\x1b[1m" + Reset = "\x1b[0m" + + HighlightRed = BrightRed + Bold + HighlightGreen = BrightGreen + Bold + HighlightYellow = BrightYellow + Bold + HighlightCyan = BrightCyan + Bold + HighlightWhite = BrightWhite + Bold +) + +func StripANSI(s string) string { + return ansiRegexp.ReplaceAllString(s, "") +} diff --git a/internal/utils/format.go b/internal/utils/strutils/format.go similarity index 88% rename from internal/utils/format.go rename to internal/utils/strutils/format.go index a77875f..c5027fc 100644 --- a/internal/utils/format.go +++ b/internal/utils/strutils/format.go @@ -1,9 +1,11 @@ -package utils +package strutils import ( "fmt" "strings" "time" + + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) func FormatDuration(d time.Duration) string { @@ -55,6 +57,10 @@ func ParseBool(s string) bool { } } +func DoYouMean(s string) string { + return "Did you mean " + ansi.HighlightGreen + s + ansi.Reset + "?" +} + func pluralize(n int64) string { if n > 1 { return "s" diff --git a/internal/utils/strutils/strconv.go b/internal/utils/strutils/strconv.go new file mode 100644 index 0000000..a6979dc --- /dev/null +++ b/internal/utils/strutils/strconv.go @@ -0,0 +1,17 @@ +package strutils + +import ( + "errors" + "strconv" + + E "github.com/yusing/go-proxy/internal/error" +) + +func Atoi(s string) (int, E.Error) { + val, err := strconv.Atoi(s) + if err != nil { + return val, E.From(errors.Unwrap(err)).Subject(s) + } + + return val, nil +} diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go new file mode 100644 index 0000000..a2af3a0 --- /dev/null +++ b/internal/utils/strutils/string.go @@ -0,0 +1,82 @@ +package strutils + +import ( + "net/url" + "strconv" + "strings" + + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +func CommaSeperatedList(s string) []string { + res := strings.Split(s, ",") + for i, part := range res { + res[i] = strings.TrimSpace(part) + } + return res +} + +func Title(s string) string { + return cases.Title(language.AmericanEnglish).String(s) +} + +func ExtractPort(fullURL string) (int, error) { + url, err := url.Parse(fullURL) + if err != nil { + return 0, err + } + return Atoi(url.Port()) +} + +func PortString(port uint16) string { + return strconv.FormatUint(uint64(port), 10) +} + +func LevenshteinDistance(a, b string) int { + if a == b { + return 0 + } + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + + v0 := make([]int, len(b)+1) + v1 := make([]int, len(b)+1) + + for i := 0; i <= len(b); i++ { + v0[i] = i + } + + for i := 0; i < len(a); i++ { + v1[0] = i + 1 + + for j := 0; j < len(b); j++ { + cost := 0 + if a[i] != b[j] { + cost = 1 + } + + v1[j+1] = min(v1[j]+1, v0[j+1]+1, v0[j]+cost) + } + + for j := 0; j <= len(b); j++ { + v0[j] = v1[j] + } + } + + return v1[len(b)] +} + +func min(a, b, c int) int { + if a < b && a < c { + return a + } + if b < a && b < c { + return b + } + return c +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index 8331d0c..a812560 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) func init() { @@ -52,6 +52,15 @@ func ExpectEqual[T comparable](t *testing.T, got T, want T) { } } +func ExpectStrEqual(t *testing.T, got string, want string) { + t.Helper() + got = ansi.StripANSI(got) + if got != want { + t.Errorf("expected:\n%v, got\n%v", want, got) + t.FailNow() + } +} + func ExpectEqualAny[T comparable](t *testing.T, got T, wants []T) { t.Helper() for _, want := range wants { @@ -97,17 +106,3 @@ func ExpectType[T any](t *testing.T, got any) (_ T) { } return got.(T) } - -func Must[T any](v T, err E.Error) T { - if err != nil { - panic(err) - } - return v -} - -func Must2[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} diff --git a/internal/watcher/config_file_watcher.go b/internal/watcher/config_file_watcher.go index 4f7c512..31087ea 100644 --- a/internal/watcher/config_file_watcher.go +++ b/internal/watcher/config_file_watcher.go @@ -1,10 +1,10 @@ package watcher import ( - "context" "sync" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/task" ) var ( @@ -17,7 +17,7 @@ func NewConfigFileWatcher(filename string) Watcher { configDirWatcherMu.Lock() defer configDirWatcherMu.Unlock() if configDirWatcher == nil { - configDirWatcher = NewDirectoryWatcher(context.Background(), common.ConfigBasePath) + configDirWatcher = NewDirectoryWatcher(task.GlobalTask("config watcher"), common.ConfigBasePath) } return configDirWatcher.Add(filename) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index d729201..d910650 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -7,13 +7,16 @@ import ( "sync" "github.com/fsnotify/fsnotify" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/task" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/watcher/events" ) type DirWatcher struct { + zerolog.Logger + dir string w *fsnotify.Watcher @@ -23,7 +26,7 @@ type DirWatcher struct { eventCh chan Event errCh chan E.Error - ctx context.Context + task task.Task } // NewDirectoryWatcher returns a DirWatcher instance. @@ -34,22 +37,26 @@ type DirWatcher struct { // // Note that the returned DirWatcher is not ready to use until the goroutine // started by NewDirectoryWatcher has finished. -func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher { +func NewDirectoryWatcher(callerSubtask task.Task, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { - logrus.Panicf("unable to create fs watcher: %s", err) + logger.Panic().Err(err).Msg("unable to create fs watcher") } if err = w.Add(dirPath); err != nil { - logrus.Panicf("unable to create fs watcher: %s", err) + logger.Panic().Err(err).Msg("unable to create fs watcher") } helper := &DirWatcher{ + Logger: logger.With(). + Str("type", "dir"). + Str("path", dirPath). + Logger(), dir: dirPath, w: w, fwMap: F.NewMapOf[string, *fileWatcher](), eventCh: make(chan Event), errCh: make(chan E.Error), - ctx: ctx, + task: callerSubtask, } go helper.start() return helper @@ -73,14 +80,10 @@ func (h *DirWatcher) Add(relPath string) Watcher { eventCh: make(chan Event), errCh: make(chan E.Error), } - go func() { - defer func() { - close(s.eventCh) - close(s.errCh) - }() - <-h.ctx.Done() - logrus.Debugf("file watcher %s stopped", relPath) - }() + h.task.OnFinished("close file watcher for "+relPath, func() { + close(s.eventCh) + close(s.errCh) + }) h.fwMap.Store(relPath, s) return s } @@ -88,11 +91,10 @@ func (h *DirWatcher) Add(relPath string) Watcher { func (h *DirWatcher) start() { defer close(h.eventCh) defer h.w.Close() - defer logrus.Debugf("directory watcher %s stopped", h.dir) for { select { - case <-h.ctx.Done(): + case <-h.task.Context().Done(): return case fsEvent, ok := <-h.w.Events: if !ok { @@ -122,9 +124,9 @@ func (h *DirWatcher) start() { // send event to directory watcher select { case h.eventCh <- msg: - logrus.Debugf("sent event to directory watcher %s", h.dir) + h.Debug().Msg("sent event to directory watcher") default: - logrus.Debugf("failed to send event to directory watcher %s", h.dir) + h.Debug().Msg("failed to send event to directory watcher") } // send event to file watcher too @@ -132,12 +134,12 @@ func (h *DirWatcher) start() { if ok { select { case w.eventCh <- msg: - logrus.Debugf("sent event to file watcher %s", relPath) + h.Debug().Msg("sent event to file watcher " + relPath) default: - logrus.Debugf("failed to send event to file watcher %s", relPath) + h.Debug().Msg("failed to send event to file watcher " + relPath) } } else { - logrus.Debugf("file watcher not found: %s", relPath) + h.Debug().Msg("file watcher not found: " + relPath) } case err := <-h.w.Errors: if errors.Is(err, fsnotify.ErrClosed) { diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 59bbe6b..5cb78b3 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -2,12 +2,11 @@ package watcher import ( "context" - "fmt" "time" docker_events "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/watcher/events" @@ -15,10 +14,11 @@ import ( type ( DockerWatcher struct { + zerolog.Logger + host string client D.Client clientOwned bool - logrus.FieldLogger } DockerListOptions = docker_events.ListOptions ) @@ -47,7 +47,7 @@ var ( dockerWatcherRetryInterval = 3 * time.Second ) -func DockerrFilterContainer(nameOrID string) filters.KeyValuePair { +func DockerFilterContainerNameID(nameOrID string) filters.KeyValuePair { return filters.Arg("container", nameOrID) } @@ -55,18 +55,21 @@ func NewDockerWatcher(host string) DockerWatcher { return DockerWatcher{ host: host, clientOwned: true, - FieldLogger: (logrus. - WithField("module", "docker_watcher"). - WithField("host", host)), + Logger: logger.With(). + Str("type", "docker"). + Str("host", host). + Logger(), } } func NewDockerWatcherWithClient(client D.Client) DockerWatcher { return DockerWatcher{ client: client, - FieldLogger: (logrus. - WithField("module", "docker_watcher"). - WithField("host", client.DaemonHost()))} + Logger: logger.With(). + Str("type", "docker"). + Str("host", client.DaemonHost()). + Logger(), + } } func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Error) { @@ -88,7 +91,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList }() if !w.client.Connected() { - var err E.Error + var err error attempts := 0 for { w.client, err = D.ConnectClient(w.host) @@ -96,7 +99,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList break } attempts++ - errCh <- E.FailWith(fmt.Sprintf("docker connection attempt #%d", attempts), err) + errCh <- E.Errorf("docker connection attempt #%d: %w", attempts, err) select { case <-ctx.Done(): return @@ -113,14 +116,14 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList for { select { case <-ctx.Done(): - if err := E.From(ctx.Err()); err != nil && err.IsNot(context.Canceled) { + if err := E.From(ctx.Err()); err != nil && !err.Is(context.Canceled) { errCh <- err } return case msg := <-cEventCh: action, ok := events.DockerEventMap[msg.Action] if !ok { - w.Debugf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"]) + w.Debug().Msgf("ignored unknown docker event: %s for container %s", msg.Action, msg.Actor.Attributes["name"]) continue } event := Event{ diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go index 992166e..d2926f6 100644 --- a/internal/watcher/events/event_queue.go +++ b/internal/watcher/events/event_queue.go @@ -65,7 +65,7 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { go func() { defer func() { if err := recover(); err != nil { - e.onError(E.PanicRecv("onFlush: %s", err).Subject(e.task.Parent().Name())) + e.onError(E.Errorf("recovered panic in onFlush: %v", err).Subject(e.task.Parent().String())) } }() e.onFlush(flushTask, queue) diff --git a/internal/watcher/health/json.go b/internal/watcher/health/json.go index 0c8ec23..9123423 100644 --- a/internal/watcher/health/json.go +++ b/internal/watcher/health/json.go @@ -5,7 +5,7 @@ import ( "time" "github.com/yusing/go-proxy/internal/net/types" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type JSONRepresentation struct { @@ -27,10 +27,10 @@ func (jsonRepr *JSONRepresentation) MarshalJSON() ([]byte, error) { "name": jsonRepr.Name, "config": jsonRepr.Config, "started": jsonRepr.Started.Unix(), - "startedStr": U.FormatTime(jsonRepr.Started), + "startedStr": strutils.FormatTime(jsonRepr.Started), "status": jsonRepr.Status.String(), "uptime": jsonRepr.Uptime.Seconds(), - "uptimeStr": U.FormatDuration(jsonRepr.Uptime), + "uptimeStr": strutils.FormatDuration(jsonRepr.Uptime), "url": url, "extra": jsonRepr.Extra, }) diff --git a/internal/watcher/health/logger.go b/internal/watcher/health/logger.go index 171f4a5..3735a4c 100644 --- a/internal/watcher/health/logger.go +++ b/internal/watcher/health/logger.go @@ -1,5 +1,7 @@ package health -import "github.com/sirupsen/logrus" +import ( + "github.com/yusing/go-proxy/internal/logging" +) -var logger = logrus.WithField("module", "health_mon") +var logger = logging.With().Str("module", "health_mon").Logger() diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go index 41773ec..0a516bf 100644 --- a/internal/watcher/health/monitor.go +++ b/internal/watcher/health/monitor.go @@ -3,10 +3,14 @@ package health import ( "context" "errors" + "fmt" + "strings" "time" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -29,6 +33,10 @@ type ( var monMap = F.NewMapOf[string, HealthMonitor]() +var ( + ErrNegativeInterval = errors.New("negative interval") +) + func newMonitor(url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { mon := &monitor{ config: config, @@ -59,10 +67,12 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { mon.task = routeSubtask if mon.config.Interval <= 0 { - return E.Invalid("interval", mon.config.Interval) + return E.From(ErrNegativeInterval) } go func() { + logger := logging.With().Str("name", mon.service).Logger() + defer func() { if mon.status.Load() != StatusError { mon.status.Store(StatusUnknown) @@ -70,7 +80,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { }() if err := mon.checkUpdateHealth(); err != nil { - logger.Errorf("healthchecker %s failure: %s", mon.service, err) + logger.Err(err).Msg("healthchecker failure") return } @@ -87,7 +97,7 @@ func (mon *monitor) Start(routeSubtask task.Task) E.Error { case <-ticker.C: err := mon.checkUpdateHealth() if err != nil { - logger.Errorf("healthchecker %s failure: %s", mon.service, err) + logger.Err(err).Msg("healthchecker failure") return } } @@ -128,10 +138,8 @@ func (mon *monitor) Uptime() time.Duration { // Name implements HealthMonitor. func (mon *monitor) Name() string { - if mon.task == nil { - return "" - } - return mon.task.Name() + parts := strings.Split(mon.service, "/") + return parts[len(parts)-1] } // String implements fmt.Stringer of HealthMonitor. @@ -151,13 +159,14 @@ func (mon *monitor) MarshalJSON() ([]byte, error) { }).MarshalJSON() } -func (mon *monitor) checkUpdateHealth() E.Error { +func (mon *monitor) checkUpdateHealth() error { + logger := logging.With().Str("name", mon.Name()).Logger() healthy, detail, err := mon.checkHealth() if err != nil { defer mon.task.Finish(err) mon.status.Store(StatusError) if !errors.Is(err, context.Canceled) { - return E.Failure("check health").With(err) + return fmt.Errorf("check health: %w", err) } return nil } @@ -169,9 +178,12 @@ func (mon *monitor) checkUpdateHealth() E.Error { } if healthy != (mon.status.Swap(status) == StatusHealthy) { if healthy { - logger.Infof("%s is up", mon.service) + logger.Info().Msg("server is up") + notif.Notify(mon.service, "server is up") } else { - logger.Warnf("%s is down: %s", mon.service, detail) + logger.Warn().Msg("server is down") + logger.Debug().Msg(detail) + notif.Notify(mon.service, "server is down") } } diff --git a/internal/watcher/logger.go b/internal/watcher/logger.go new file mode 100644 index 0000000..ad15593 --- /dev/null +++ b/internal/watcher/logger.go @@ -0,0 +1,5 @@ +package watcher + +import "github.com/yusing/go-proxy/internal/logging" + +var logger = logging.With().Str("module", "watcher").Logger()