From 7184c9cfe985c36f9fa7900d449aa68f99c808ba Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 11 Oct 2024 09:13:38 +0800 Subject: [PATCH] correcting some behaviors for $DOCKER_HOST, now uses container's private IP instead of localhost --- .golangci.yml | 136 +++++++++++++ .trunk/.gitignore | 9 + .trunk/trunk.yaml | 41 ++++ go.mod | 2 +- go.sum | 4 +- internal/api/handler.go | 9 +- internal/api/v1/checkhealth.go | 6 +- .../{error_page => errorpage}/error_page.go | 26 +-- .../{error_page => errorpage}/http_handler.go | 12 +- internal/api/v1/file.go | 6 +- internal/api/v1/index.go | 8 +- internal/api/v1/list.go | 12 +- internal/api/v1/reload.go | 2 +- internal/api/v1/stats.go | 8 +- internal/api/v1/utils/http_client.go | 34 ++-- internal/api/v1/utils/utils.go | 20 +- internal/api/v1/version.go | 3 +- internal/autocert/provider.go | 27 +-- internal/common/constants.go | 14 +- internal/config/config.go | 1 - internal/docker/client.go | 40 ++-- internal/docker/client_info.go | 1 - internal/docker/container.go | 182 ++++++++---------- internal/docker/container_helper.go | 90 +++++++++ internal/docker/idlewatcher/http.go | 4 +- internal/docker/idlewatcher/waker.go | 10 +- internal/docker/idlewatcher/watcher.go | 35 ++-- internal/docker/inspect.go | 6 +- internal/docker/label.go | 2 +- internal/docker/label_test.go | 18 +- internal/docker/logger.go | 5 + internal/docker/proxy_properties.go | 25 --- internal/error/builder.go | 18 +- internal/error/builder_test.go | 6 +- internal/error/error.go | 64 +++--- internal/error/error_test.go | 6 +- internal/list-icons.go | 18 +- .../net/http/loadbalancer/loadbalancer.go | 4 +- .../net/http/middleware/custom_error_page.go | 41 ++-- .../net/http/middleware/middleware_builder.go | 4 +- internal/net/http/middleware/oauth2.go | 8 +- internal/net/http/middleware/test_utils.go | 13 +- internal/net/http/reverse_proxy_mod.go | 13 +- internal/proxy/fields/host.go | 6 +- internal/proxy/fields/path_pattern.go | 16 +- internal/proxy/provider/docker.go | 44 ++--- internal/proxy/provider/docker_test.go | 27 +-- internal/proxy/provider/provider.go | 31 +-- internal/route/constants.go | 6 +- internal/route/http.go | 36 ++-- internal/route/stream.go | 2 +- internal/route/udp.go | 1 - internal/server/instance.go | 10 +- internal/server/server.go | 21 +- internal/setup.go | 32 +-- internal/types/autocert_config.go | 12 +- internal/types/config.go | 12 +- internal/types/proxy_providers.go | 4 +- internal/types/raw_entry.go | 93 +++++---- internal/utils/fs.go | 2 +- internal/utils/functional/map.go | 63 +++++- internal/utils/io.go | 6 +- internal/utils/nocopy.go | 8 + internal/utils/serialization.go | 43 +++-- internal/utils/string.go | 6 +- internal/utils/testing/testing.go | 1 - internal/watcher/config_file_watcher.go | 8 +- internal/watcher/directory_watcher.go | 12 +- 68 files changed, 925 insertions(+), 570 deletions(-) create mode 100644 .golangci.yml create mode 100644 .trunk/.gitignore create mode 100644 .trunk/trunk.yaml rename internal/api/v1/{error_page => errorpage}/error_page.go (74%) rename internal/api/v1/{error_page => errorpage}/http_handler.go (73%) create mode 100644 internal/docker/container_helper.go create mode 100644 internal/docker/logger.go delete mode 100644 internal/docker/proxy_properties.go create mode 100644 internal/utils/nocopy.go diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..7d65a80 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,136 @@ +run: + timeout: 10m + +linters-settings: + govet: + enable-all: true + disable: + - shadow + - fieldalignment + gocyclo: + min-complexity: 14 + goconst: + min-len: 3 + min-occurrences: 4 + misspell: + locale: US + funlen: + lines: -1 + statements: 120 + forbidigo: + forbid: + - ^print(ln)?$ + godox: + keywords: + - FIXME + tagalign: + align: false + sort: true + order: + - description + - json + - toml + - yaml + - yml + - label + - label-slice-as-struct + - file + - kv + - export + stylecheck: + dot-import-whitelist: + - github.com/yusing/go-proxy/internal/utils/testing # go tests only + - github.com/yusing/go-proxy/internal/api/v1/utils # api only + revive: + rules: + - name: struct-tag + - name: blank-imports + - name: context-as-argument + - name: context-keys-type + - name: error-return + - name: error-strings + - name: error-naming + - name: exported + disabled: true + - name: if-return + - name: increment-decrement + - name: var-naming + - name: var-declaration + - name: package-comments + disabled: true + - name: range + - name: receiver-naming + - name: time-naming + - name: unexported-return + - name: indent-error-flow + - name: errorf + - name: empty-block + - name: superfluous-else + - name: unused-parameter + disabled: true + - name: unreachable-code + - name: redefines-builtin-id + gomoddirectives: + replace-allow-list: + - github.com/abbot/go-http-auth + - github.com/gorilla/mux + - github.com/mailgun/minheap + - github.com/mailgun/multibuf + - github.com/jaguilar/vt100 + - github.com/cucumber/godog + - github.com/http-wasm/http-wasm-host-go + testifylint: + disable: + - suite-dont-use-pkg + - require-error + - go-require + staticcheck: + checks: + - all + - -SA1019 + errcheck: + exclude-functions: + - fmt.Fprintln +linters: + enable-all: true + disable: + - execinquery # deprecated + - gomnd # deprecated + - sqlclosecheck # not relevant (SQL) + - rowserrcheck # not relevant (SQL) + - cyclop # duplicate of gocyclo + - depguard # Not relevant + - nakedret # Too strict + - lll # Not relevant + - gocyclo # FIXME must be fixed + - gocognit # Too strict + - nestif # Too many false-positive. + - prealloc # Too many false-positive. + - makezero # Not relevant + - dupl # Too strict + - gosec # Too strict + - gochecknoinits + - gochecknoglobals + - wsl # Too strict + - nlreturn # Not relevant + - mnd # Too strict + - testpackage # Too strict + - tparallel # Not relevant + - paralleltest # Not relevant + - exhaustive # Not relevant + - exhaustruct # Not relevant + - err113 # Too strict + - wrapcheck # Too strict + - noctx # Too strict + - bodyclose # too many false-positive + - forcetypeassert # Too strict + - tagliatelle # Too strict + - varnamelen # Not relevant + - nilnil # Not relevant + - ireturn # Not relevant + - contextcheck # too many false-positive + - containedctx # too many false-positive + - maintidx # kind of duplicate of gocyclo + - nonamedreturns # Too strict + - gosmopolitan # not relevant + - exportloopref # Not relevant since go1.22 diff --git a/.trunk/.gitignore b/.trunk/.gitignore new file mode 100644 index 0000000..15966d0 --- /dev/null +++ b/.trunk/.gitignore @@ -0,0 +1,9 @@ +*out +*logs +*actions +*notifications +*tools +plugins +user_trunk.yaml +user.yaml +tmp diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml new file mode 100644 index 0000000..1c2d6d9 --- /dev/null +++ b/.trunk/trunk.yaml @@ -0,0 +1,41 @@ +# This file controls the behavior of Trunk: https://docs.trunk.io/cli +# To learn more about the format of this file, see https://docs.trunk.io/reference/trunk-yaml +version: 0.1 +cli: + version: 1.22.6 +# Trunk provides extensibility via plugins. (https://docs.trunk.io/plugins) +plugins: + sources: + - id: trunk + ref: v1.6.3 + uri: https://github.com/trunk-io/plugins +# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes) +runtimes: + enabled: + - node@18.12.1 + - python@3.10.8 + - go@1.23.2 +# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration) +lint: + enabled: + - hadolint@2.12.0 + - actionlint@1.7.3 + - checkov@3.2.257 + - git-diff-check + - gofmt@1.20.4 + - golangci-lint@1.61.0 + - markdownlint@0.42.0 + - osv-scanner@1.9.0 + - oxipng@9.1.2 + - prettier@3.3.3 + - shellcheck@0.10.0 + - shfmt@3.6.0 + - trufflehog@3.82.7 + - yamllint@1.35.1 +actions: + disabled: + - trunk-announce + - trunk-check-pre-push + - trunk-fmt-pre-commit + enabled: + - trunk-upgrade-available diff --git a/go.mod b/go.mod index e7a4495..6425cee 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cloudflare/cloudflare-go v0.106.0 // indirect + github.com/cloudflare/cloudflare-go v0.107.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect diff --git a/go.sum b/go.sum index d4f7e02..3c4fe09 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/cloudflare/cloudflare-go v0.106.0 h1:q41gC5Wc1nfi0D1ZhSHokWcd9mGMbqC7RE7qiP+qE00= -github.com/cloudflare/cloudflare-go v0.106.0/go.mod h1:pfUQ4PIG4ISI0/Mmc21Bp86UnFU0ktmPf3iTgbSL+cM= +github.com/cloudflare/cloudflare-go v0.107.0 h1:cMDIw2tzt6TXCJyMFVyP+BPOVkIfMvcKjhMNSNvuEPc= +github.com/cloudflare/cloudflare-go v0.107.0/go.mod h1:5cYGzVBqNTLxMYSLdVjuSs5LJL517wJDSvMPWUrzHzc= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= diff --git a/internal/api/handler.go b/internal/api/handler.go index 29e91d9..414657e 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -5,7 +5,7 @@ import ( "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" - "github.com/yusing/go-proxy/internal/api/v1/error_page" + "github.com/yusing/go-proxy/internal/api/v1/errorpage" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -36,11 +36,11 @@ func NewHandler(cfg *config.Config) http.Handler { mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats/ws", wrap(cfg, v1.StatsWS)) - mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc()) + mux.HandleFunc("GET", "/v1/error_page", errorpage.GetHandleFunc()) return mux } -// allow only requests to API server with host matching common.APIHTTPAddr +// allow only requests to API server with host matching common.APIHTTPAddr. func checkHost(f http.HandlerFunc) http.HandlerFunc { if common.IsDebug { return f @@ -48,8 +48,7 @@ func checkHost(f http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Host != common.APIHTTPAddr { Logger.Warnf("invalid request to API server with host: %s, expect %s", r.Host, common.APIHTTPAddr) - w.WriteHeader(http.StatusForbidden) - w.Write([]byte("invalid request")) + http.Error(w, "invalid request", http.StatusForbidden) return } f(w, r) diff --git a/internal/api/v1/checkhealth.go b/internal/api/v1/checkhealth.go index bba30cd..abd659e 100644 --- a/internal/api/v1/checkhealth.go +++ b/internal/api/v1/checkhealth.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - U "github.com/yusing/go-proxy/internal/api/v1/utils" + . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/config" R "github.com/yusing/go-proxy/internal/route" ) @@ -13,7 +13,7 @@ import ( func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { target := r.FormValue("target") if target == "" { - U.HandleErr(w, r, U.ErrMissingKey("target"), http.StatusBadRequest) + HandleErr(w, r, ErrMissingKey("target"), http.StatusBadRequest) return } @@ -22,7 +22,7 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { switch { case route == nil: - U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound) + HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound) return case route.Type() == R.RouteTypeReverseProxy: ok = IsSiteHealthy(route.URL().String()) diff --git a/internal/api/v1/error_page/error_page.go b/internal/api/v1/errorpage/error_page.go similarity index 74% rename from internal/api/v1/error_page/error_page.go rename to internal/api/v1/errorpage/error_page.go index a2a46e3..cb796bf 100644 --- a/internal/api/v1/error_page/error_page.go +++ b/internal/api/v1/errorpage/error_page.go @@ -1,4 +1,4 @@ -package error_page +package errorpage import ( "context" @@ -7,7 +7,7 @@ import ( "path" "sync" - api "github.com/yusing/go-proxy/internal/api/v1/utils" + . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -17,6 +17,11 @@ import ( const errPagesBasePath = common.ErrorPagesBasePath +var ( + dirWatcher W.Watcher + fileContentMap = F.NewMapOf[string, []byte]() +) + var setup = sync.OnceFunc(func() { dirWatcher = W.NewDirectoryWatcher(context.Background(), errPagesBasePath) loadContent() @@ -27,7 +32,7 @@ func GetStaticFile(filename string) ([]byte, bool) { return fileContentMap.Load(filename) } -// try .html -> 404.html -> not ok +// try .html -> 404.html -> not ok. func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { content, ok = fileContentMap.Load(fmt.Sprintf("%d.html", statusCode)) if !ok && statusCode != 404 { @@ -39,7 +44,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { func loadContent() { files, err := U.ListFiles(errPagesBasePath, 0) if err != nil { - api.Logger.Error(err) + Logger.Error(err) return } for _, file := range files { @@ -48,11 +53,11 @@ func loadContent() { } content, err := os.ReadFile(file) if err != nil { - api.Logger.Errorf("failed to read error page resource %s: %s", file, err) + Logger.Errorf("failed to read error page resource %s: %s", file, err) continue } file = path.Base(file) - api.Logger.Infof("error page resource %s loaded", file) + Logger.Infof("error page resource %s loaded", file) fileContentMap.Store(file, content) } } @@ -72,17 +77,14 @@ func watchDir() { loadContent() case events.ActionFileDeleted: fileContentMap.Delete(filename) - api.Logger.Infof("error page resource %s deleted", filename) + Logger.Infof("error page resource %s deleted", filename) case events.ActionFileRenamed: - api.Logger.Infof("error page resource %s deleted", filename) + Logger.Infof("error page resource %s deleted", filename) fileContentMap.Delete(filename) loadContent() } case err := <-errCh: - api.Logger.Errorf("error watching error page directory: %s", err) + Logger.Errorf("error watching error page directory: %s", err) } } } - -var dirWatcher W.Watcher -var fileContentMap = F.NewMapOf[string, []byte]() diff --git a/internal/api/v1/error_page/http_handler.go b/internal/api/v1/errorpage/http_handler.go similarity index 73% rename from internal/api/v1/error_page/http_handler.go rename to internal/api/v1/errorpage/http_handler.go index 826fd9d..2da9372 100644 --- a/internal/api/v1/error_page/http_handler.go +++ b/internal/api/v1/errorpage/http_handler.go @@ -1,6 +1,10 @@ -package error_page +package errorpage -import "net/http" +import ( + "net/http" + + . "github.com/yusing/go-proxy/internal/api/v1/utils" +) func GetHandleFunc() http.HandlerFunc { setup() @@ -21,5 +25,7 @@ func serveHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "404 not found", http.StatusNotFound) return } - w.Write(content) + if _, err := w.Write(content); err != nil { + HandleErr(w, r, err) + } } diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index e413ae5..9944a2d 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -24,7 +24,7 @@ func GetFileContent(w http.ResponseWriter, r *http.Request) { U.HandleErr(w, r, err) return } - w.Write(content) + U.WriteBody(w, content) } func SetFileContent(w http.ResponseWriter, r *http.Request) { @@ -47,11 +47,11 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { } if validateErr != nil { - U.RespondJson(w, validateErr.JSONObject(), http.StatusBadRequest) + U.RespondJSON(w, r, validateErr.JSONObject(), http.StatusBadRequest) return } - err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644) + err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0o644) if err != nil { U.HandleErr(w, r, err) return diff --git a/internal/api/v1/index.go b/internal/api/v1/index.go index 6d887fa..71bbf2e 100644 --- a/internal/api/v1/index.go +++ b/internal/api/v1/index.go @@ -1,7 +1,11 @@ package v1 -import "net/http" +import ( + "net/http" + + . "github.com/yusing/go-proxy/internal/api/v1/utils" +) func Index(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("API ready")) + WriteBody(w, []byte("API ready")) } diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 392b303..66a4076 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -55,7 +55,7 @@ func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) { } } - U.HandleErr(w, r, U.RespondJson(w, routes)) + U.RespondJSON(w, r, routes) } func listConfigFiles(w http.ResponseWriter, r *http.Request) { @@ -67,21 +67,21 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) { for i := range files { files[i] = strings.TrimPrefix(files[i], common.ConfigBasePath+"/") } - U.HandleErr(w, r, U.RespondJson(w, files)) + U.RespondJSON(w, r, files) } func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) { - U.HandleErr(w, r, U.RespondJson(w, middleware.GetAllTrace())) + U.RespondJSON(w, r, middleware.GetAllTrace()) } func listMiddlewares(w http.ResponseWriter, r *http.Request) { - U.HandleErr(w, r, U.RespondJson(w, middleware.All())) + U.RespondJSON(w, r, middleware.All()) } func listMatchDomains(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.HandleErr(w, r, U.RespondJson(w, cfg.Value().MatchDomains)) + U.RespondJSON(w, r, cfg.Value().MatchDomains) } func listHomepageConfig(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.HandleErr(w, r, U.RespondJson(w, cfg.HomepageConfig())) + U.RespondJSON(w, r, cfg.HomepageConfig()) } diff --git a/internal/api/v1/reload.go b/internal/api/v1/reload.go index c62c61e..da2c3a5 100644 --- a/internal/api/v1/reload.go +++ b/internal/api/v1/reload.go @@ -9,7 +9,7 @@ import ( func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { if err := cfg.Reload(); err != nil { - U.RespondJson(w, err.JSONObject(), http.StatusInternalServerError) + U.RespondJSON(w, r, err.JSONObject(), http.StatusInternalServerError) } else { w.WriteHeader(http.StatusOK) } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 29bcddc..46ea7f7 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -5,18 +5,17 @@ import ( "net/http" "time" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/utils" - - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" ) func Stats(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - U.HandleErr(w, r, U.RespondJson(w, getStats(cfg))) + U.RespondJSON(w, r, getStats(cfg)) } func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { @@ -42,6 +41,7 @@ func StatsWS(cfg *config.Config, w http.ResponseWriter, r *http.Request) { U.Logger.Errorf("/stats/ws failed to upgrade websocket: %s", err) return } + /* trunk-ignore(golangci-lint/errcheck) */ defer conn.CloseNow() ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/api/v1/utils/http_client.go b/internal/api/v1/utils/http_client.go index a3afb7b..0cb4ebe 100644 --- a/internal/api/v1/utils/http_client.go +++ b/internal/api/v1/utils/http_client.go @@ -8,20 +8,22 @@ import ( "github.com/yusing/go-proxy/internal/common" ) -var HTTPClient = &http.Client{ - Timeout: common.ConnectionTimeout, - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DisableKeepAlives: true, - ForceAttemptHTTP2: true, - DialContext: (&net.Dialer{ - Timeout: common.DialTimeout, - KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives - }).DialContext, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, -} +var ( + HTTPClient = &http.Client{ + Timeout: common.ConnectionTimeout, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DisableKeepAlives: true, + ForceAttemptHTTP2: true, + DialContext: (&net.Dialer{ + Timeout: common.DialTimeout, + KeepAlive: common.KeepAlive, // this is different from DisableKeepAlives + }).DialContext, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } -var Get = HTTPClient.Get -var Post = HTTPClient.Post -var Head = HTTPClient.Head + Get = HTTPClient.Get + Post = HTTPClient.Post + Head = HTTPClient.Head +) diff --git a/internal/api/v1/utils/utils.go b/internal/api/v1/utils/utils.go index da4199a..c687b44 100644 --- a/internal/api/v1/utils/utils.go +++ b/internal/api/v1/utils/utils.go @@ -5,16 +5,26 @@ import ( "net/http" ) -func RespondJson(w http.ResponseWriter, data any, code ...int) error { +func WriteBody(w http.ResponseWriter, body []byte) { + if _, err := w.Write(body); err != nil { + HandleErr(w, nil, err) + } +} + +func RespondJSON(w http.ResponseWriter, r *http.Request, data any, code ...int) bool { if len(code) > 0 { w.WriteHeader(code[0]) } w.Header().Set("Content-Type", "application/json") j, err := json.MarshalIndent(data, "", " ") if err != nil { - return err - } else { - w.Write(j) + HandleErr(w, r, err) + return false } - return nil + _, err = w.Write(j) + if err != nil { + HandleErr(w, r, err) + return false + } + return true } diff --git a/internal/api/v1/version.go b/internal/api/v1/version.go index bbfc2c5..f0db6bf 100644 --- a/internal/api/v1/version.go +++ b/internal/api/v1/version.go @@ -3,9 +3,10 @@ package v1 import ( "net/http" + . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/pkg" ) func GetVersion(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(pkg.GetVersion())) + WriteBody(w, []byte(pkg.GetVersion())) } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 63575d0..be95da0 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -16,22 +16,23 @@ import ( "github.com/go-acme/lego/v4/registration" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/types" - U "github.com/yusing/go-proxy/internal/utils" ) -type Provider struct { - cfg *Config - user *User - legoCfg *lego.Config - client *lego.Client +type ( + Provider struct { + cfg *Config + user *User + legoCfg *lego.Config + client *lego.Client - tlsCert *tls.Certificate - certExpiries CertExpiries -} + tlsCert *tls.Certificate + certExpiries CertExpiries + } + ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError) -type ProviderGenerator func(types.AutocertProviderOpt) (challenge.Provider, E.NestedError) -type CertExpiries map[string]time.Time + CertExpiries map[string]time.Time +) func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { @@ -192,8 +193,8 @@ func (p *Provider) registerACME() E.NestedError { } func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { - //* This should have been done in setup - //* but double check is always a good choice + /* This should have been done in setup + but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) if err != nil { if os.IsNotExist(err) { diff --git a/internal/common/constants.go b/internal/common/constants.go index 57a5093..abba4f0 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -36,14 +36,12 @@ const ( ErrorPagesBasePath = "error_pages" ) -var ( - RequiredDirectories = []string{ - ConfigBasePath, - SchemaBasePath, - ErrorPagesBasePath, - MiddlewareComposeBasePath, - } -) +var RequiredDirectories = []string{ + ConfigBasePath, + SchemaBasePath, + ErrorPagesBasePath, + MiddlewareComposeBasePath, +} const DockerHostFromEnv = "$DOCKER_HOST" diff --git a/internal/config/config.go b/internal/config/config.go index 071a524..9f40b94 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,7 +8,6 @@ import ( "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" - PR "github.com/yusing/go-proxy/internal/proxy/provider" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/types" diff --git a/internal/docker/client.go b/internal/docker/client.go index 8d2845f..074d730 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -21,29 +21,21 @@ type Client struct { l logrus.FieldLogger } -func ParseDockerHostname(host string) (string, E.NestedError) { - switch host { - case common.DockerHostFromEnv, "": - return "localhost", nil - } - url, err := E.Check(client.ParseHostURL(host)) - if err != nil { - return "", E.Invalid("host", host).With(err) - } - return url.Hostname(), nil -} +var ( + clientMap F.Map[string, Client] = F.NewMapOf[string, Client]() + clientMapMu sync.Mutex -func (c Client) DaemonHostname() string { - // DaemonHost should always return a valid host - hostname, _ := ParseDockerHostname(c.DaemonHost()) - return hostname -} + clientOptEnvHost = []client.Opt{ + client.WithHostFromEnv(), + client.WithAPIVersionNegotiation(), + } +) func (c Client) Connected() bool { return c.Client != nil } -// if the client is still referenced, this is no-op +// if the client is still referenced, this is no-op. func (c *Client) Close() error { if c.refCount.Add(-1) > 0 { return nil @@ -86,6 +78,8 @@ func ConnectClient(host string) (Client, E.NestedError) { var opt []client.Opt switch host { + case "": + return Client{}, E.Invalid("docker host", "empty") case common.DockerHostFromEnv: opt = clientOptEnvHost default: @@ -139,15 +133,3 @@ func CloseAllClients() { clientMap.Clear() logger.Debug("closed all clients") } - -var ( - clientMap F.Map[string, Client] = F.NewMapOf[string, Client]() - clientMapMu sync.Mutex - - clientOptEnvHost = []client.Opt{ - client.WithHostFromEnv(), - client.WithAPIVersionNegotiation(), - } - - logger = logrus.WithField("module", "docker") -) diff --git a/internal/docker/client_info.go b/internal/docker/client_info.go index 8446393..6228920 100644 --- a/internal/docker/client_info.go +++ b/internal/docker/client_info.go @@ -7,7 +7,6 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" - E "github.com/yusing/go-proxy/internal/error" ) diff --git a/internal/docker/container.go b/internal/docker/container.go index f670f06..c3bfaa9 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -1,45 +1,78 @@ package docker import ( - "fmt" + "net/url" "strconv" "strings" "github.com/docker/docker/api/types" + "github.com/sirupsen/logrus" U "github.com/yusing/go-proxy/internal/utils" ) -type Container struct { - *types.Container - *ProxyProperties -} +type ( + PortMapping = map[string]types.Port + Container struct { + _ U.NoCopy -func FromDocker(c *types.Container, dockerHost string) (res Container) { - res.Container = c - isExplicit := c.Labels[LabelAliases] != "" - res.ProxyProperties = &ProxyProperties{ - DockerHost: dockerHost, - ContainerName: res.getName(), - ContainerID: c.ID, - ImageName: res.getImageName(), - PublicPortMapping: res.getPublicPortMapping(), - PrivatePortMapping: res.getPrivatePortMapping(), - NetworkMode: c.HostConfig.NetworkMode, - Aliases: res.getAliases(), - IsExcluded: U.ParseBool(res.getDeleteLabel(LabelExclude)), - IsExplicit: isExplicit, - IsDatabase: res.isDatabase(), - IdleTimeout: res.getDeleteLabel(LabelIdleTimeout), - WakeTimeout: res.getDeleteLabel(LabelWakeTimeout), - StopMethod: res.getDeleteLabel(LabelStopMethod), - StopTimeout: res.getDeleteLabel(LabelStopTimeout), - StopSignal: res.getDeleteLabel(LabelStopSignal), - Running: c.Status == "running" || c.State == "running", + DockerHost string `json:"docker_host" yaml:"-"` + ContainerName string `json:"container_name" yaml:"-"` + ContainerID string `json:"container_id" yaml:"-"` + ImageName string `json:"image_name" yaml:"-"` + + Labels map[string]string `json:"labels" yaml:"-"` + + PublicPortMapping PortMapping `json:"public_ports" yaml:"-"` // non-zero publicPort:types.Port + PrivatePortMapping PortMapping `json:"private_ports" yaml:"-"` // privatePort:types.Port + PublicIP string `json:"public_ip" yaml:"-"` + PrivateIP string `json:"private_ip" yaml:"-"` + NetworkMode string `json:"network_mode" yaml:"-"` + + Aliases []string `json:"aliases" yaml:"-"` + IsExcluded bool `json:"is_excluded" yaml:"-"` + IsExplicit bool `json:"is_explicit" yaml:"-"` + IsDatabase bool `json:"is_database" yaml:"-"` + IdleTimeout string `json:"idle_timeout" yaml:"-"` + WakeTimeout string `json:"wake_timeout" yaml:"-"` + StopMethod string `json:"stop_method" yaml:"-"` + StopTimeout string `json:"stop_timeout" yaml:"-"` // stop_method = "stop" only + StopSignal string `json:"stop_signal" yaml:"-"` // stop_method = "stop" | "kill" only + Running bool `json:"running" yaml:"-"` } +) + +func FromDocker(c *types.Container, dockerHost string) (res *Container) { + isExplicit := c.Labels[LabelAliases] != "" + helper := containerHelper{c} + res = &Container{ + DockerHost: dockerHost, + ContainerName: helper.getName(), + ContainerID: c.ID, + ImageName: helper.getImageName(), + + Labels: c.Labels, + + PublicPortMapping: helper.getPublicPortMapping(), + PrivatePortMapping: helper.getPrivatePortMapping(), + NetworkMode: c.HostConfig.NetworkMode, + + Aliases: helper.getAliases(), + IsExcluded: U.ParseBool(helper.getDeleteLabel(LabelExclude)), + IsExplicit: isExplicit, + IsDatabase: helper.isDatabase(), + IdleTimeout: helper.getDeleteLabel(LabelIdleTimeout), + WakeTimeout: helper.getDeleteLabel(LabelWakeTimeout), + StopMethod: helper.getDeleteLabel(LabelStopMethod), + StopTimeout: helper.getDeleteLabel(LabelStopTimeout), + StopSignal: helper.getDeleteLabel(LabelStopSignal), + Running: c.Status == "running" || c.State == "running", + } + res.setPrivateIP(helper) + res.setPublicIP() return } -func FromJson(json types.ContainerJSON, dockerHost string) Container { +func FromJSON(json types.ContainerJSON, dockerHost string) *Container { ports := make([]types.Port, 0) for k, bindings := range json.NetworkSettings.Ports { for _, v := range bindings { @@ -65,79 +98,32 @@ func FromJson(json types.ContainerJSON, dockerHost string) Container { return cont } -func (c Container) getDeleteLabel(label string) string { - if l, ok := c.Labels[label]; ok { - delete(c.Labels, label) - return l +func (c *Container) setPublicIP() { + if c.PublicPortMapping == nil { + return } - return "" + if strings.HasPrefix(c.DockerHost, "unix://") { + c.PublicIP = "127.0.0.1" + return + } + url, err := url.Parse(c.DockerHost) + if err != nil { + logrus.Errorf("invalid docker host %q: %v\nfalling back to 127.0.0.1", c.DockerHost, err) + c.PublicIP = "127.0.0.1" + return + } + c.PublicIP = url.Hostname() } -func (c Container) getAliases() []string { - if l := c.getDeleteLabel(LabelAliases); l != "" { - return U.CommaSeperatedList(l) - } else { - return []string{c.getName()} +func (c *Container) setPrivateIP(helper containerHelper) { + if !strings.HasPrefix(c.DockerHost, "unix://") { + return + } + if helper.NetworkSettings == nil { + return + } + for _, v := range helper.NetworkSettings.Networks { + c.PrivateIP = v.IPAddress + return } } - -func (c Container) getName() string { - return strings.TrimPrefix(c.Names[0], "/") -} - -func (c Container) getImageName() string { - colonSep := strings.Split(c.Image, ":") - slashSep := strings.Split(colonSep[0], "/") - return slashSep[len(slashSep)-1] -} - -func (c Container) getPublicPortMapping() PortMapping { - res := make(PortMapping) - for _, v := range c.Ports { - if v.PublicPort == 0 { - continue - } - res[fmt.Sprint(v.PublicPort)] = v - } - return res -} - -func (c Container) getPrivatePortMapping() PortMapping { - res := make(PortMapping) - for _, v := range c.Ports { - res[fmt.Sprint(v.PrivatePort)] = v - } - return res -} - -var databaseMPs = map[string]struct{}{ - "/var/lib/postgresql/data": {}, - "/var/lib/mysql": {}, - "/var/lib/mongodb": {}, - "/var/lib/mariadb": {}, - "/var/lib/memcached": {}, - "/var/lib/rabbitmq": {}, -} - -var databasePrivPorts = map[uint16]struct{}{ - 5432: {}, // postgres - 3306: {}, // mysql, mariadb - 6379: {}, // redis - 11211: {}, // memcached - 27017: {}, // mongodb -} - -func (c Container) isDatabase() bool { - for _, m := range c.Container.Mounts { - if _, ok := databaseMPs[m.Destination]; ok { - return true - } - } - - for _, v := range c.Ports { - if _, ok := databasePrivPorts[v.PrivatePort]; ok { - return true - } - } - return false -} diff --git a/internal/docker/container_helper.go b/internal/docker/container_helper.go new file mode 100644 index 0000000..12d6cc7 --- /dev/null +++ b/internal/docker/container_helper.go @@ -0,0 +1,90 @@ +package docker + +import ( + "strings" + + "github.com/docker/docker/api/types" + U "github.com/yusing/go-proxy/internal/utils" +) + +type containerHelper struct { + *types.Container +} + +// getDeleteLabel gets the value of a label and then deletes it from the container. +// If the label does not exist, an empty string is returned. +func (c containerHelper) getDeleteLabel(label string) string { + if l, ok := c.Labels[label]; ok { + delete(c.Labels, label) + return l + } + return "" +} + +func (c containerHelper) getAliases() []string { + if l := c.getDeleteLabel(LabelAliases); l != "" { + return U.CommaSeperatedList(l) + } + return []string{c.getName()} +} + +func (c containerHelper) getName() string { + return strings.TrimPrefix(c.Names[0], "/") +} + +func (c containerHelper) getImageName() string { + colonSep := strings.Split(c.Image, ":") + slashSep := strings.Split(colonSep[0], "/") + return slashSep[len(slashSep)-1] +} + +func (c containerHelper) getPublicPortMapping() PortMapping { + res := make(PortMapping) + for _, v := range c.Ports { + if v.PublicPort == 0 { + continue + } + res[U.PortString(v.PublicPort)] = v + } + return res +} + +func (c containerHelper) getPrivatePortMapping() PortMapping { + res := make(PortMapping) + for _, v := range c.Ports { + res[U.PortString(v.PrivatePort)] = v + } + return res +} + +var databaseMPs = map[string]struct{}{ + "/var/lib/postgresql/data": {}, + "/var/lib/mysql": {}, + "/var/lib/mongodb": {}, + "/var/lib/mariadb": {}, + "/var/lib/memcached": {}, + "/var/lib/rabbitmq": {}, +} + +var databasePrivPorts = map[uint16]struct{}{ + 5432: {}, // postgres + 3306: {}, // mysql, mariadb + 6379: {}, // redis + 11211: {}, // memcached + 27017: {}, // mongodb +} + +func (c containerHelper) isDatabase() bool { + for _, m := range c.Mounts { + if _, ok := databaseMPs[m.Destination]; ok { + return true + } + } + + for _, v := range c.Ports { + if _, ok := databasePrivPorts[v.PrivatePort]; ok { + return true + } + } + return false +} diff --git a/internal/docker/idlewatcher/http.go b/internal/docker/idlewatcher/http.go index 19fbabb..bb000af 100644 --- a/internal/docker/idlewatcher/http.go +++ b/internal/docker/idlewatcher/http.go @@ -18,9 +18,9 @@ type templateData struct { var loadingPage []byte var loadingPageTmpl = template.Must(template.New("loading_page").Parse(string(loadingPage))) -const headerCheckRedirect = "X-GoProxy-Check-Redirect" +const headerCheckRedirect = "X-Goproxy-Check-Redirect" -func (w *watcher) makeRespBody(format string, args ...any) []byte { +func (w *Watcher) makeRespBody(format string, args ...any) []byte { msg := fmt.Sprintf(format, args...) data := new(templateData) diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 9e00174..da9f94e 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -11,13 +11,13 @@ import ( ) type Waker struct { - *watcher + *Watcher client *http.Client rp *gphttp.ReverseProxy } -func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker { +func NewWaker(w *Watcher, rp *gphttp.ReverseProxy) *Waker { orig := rp.ServeHTTP // workaround for stopped containers port become zero rp.ServeHTTP = func(rw http.ResponseWriter, r *http.Request) { @@ -33,7 +33,7 @@ func NewWaker(w *watcher, rp *gphttp.ReverseProxy) *Waker { orig(rw, r) } return &Waker{ - watcher: w, + Watcher: w, client: &http.Client{ Timeout: 1 * time.Second, Transport: rp.Transport, @@ -70,7 +70,9 @@ func (w *Waker) wake(next http.HandlerFunc, rw http.ResponseWriter, r *http.Requ rw.Header().Add("Cache-Control", "no-cache") rw.Header().Add("Cache-Control", "no-store") rw.Header().Add("Cache-Control", "must-revalidate") - rw.Write(body) + if _, err := rw.Write(body); err != nil { + w.l.Errorf("error writing http response: %s", err) + } return } diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index c27cb0e..98c0ec8 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -17,7 +17,7 @@ import ( ) type ( - watcher struct { + Watcher struct { *P.ReverseProxyEntry client D.Client @@ -46,17 +46,17 @@ var ( mainLoopCancel context.CancelFunc mainLoopWg sync.WaitGroup - watcherMap = F.NewMapOf[string, *watcher]() + watcherMap = F.NewMapOf[string, *Watcher]() watcherMapMu sync.Mutex portHistoryMap = F.NewMapOf[PT.Alias, string]() - newWatcherCh = make(chan *watcher) + newWatcherCh = make(chan *Watcher) logger = logrus.WithField("module", "idle_watcher") ) -func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { +func Register(entry *P.ReverseProxyEntry) (*Watcher, E.NestedError) { failure := E.Failure("idle_watcher register") if entry.IdleTimeout == 0 { @@ -83,7 +83,7 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { return nil, failure.With(err) } - w := &watcher{ + w := &Watcher{ ReverseProxyEntry: entry, client: client, refCount: &sync.WaitGroup{}, @@ -104,7 +104,7 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { return w, nil } -func (w *watcher) Unregister() { +func (w *Watcher) Unregister() { w.refCount.Add(-1) } @@ -138,29 +138,30 @@ func Stop() { mainLoopWg.Wait() } -func (w *watcher) containerStop() error { +func (w *Watcher) containerStop() error { return w.client.ContainerStop(w.ctx, w.ContainerID, container.StopOptions{ Signal: string(w.StopSignal), - Timeout: &w.StopTimeout}) + Timeout: &w.StopTimeout, + }) } -func (w *watcher) containerPause() error { +func (w *Watcher) containerPause() error { return w.client.ContainerPause(w.ctx, w.ContainerID) } -func (w *watcher) containerKill() error { +func (w *Watcher) containerKill() error { return w.client.ContainerKill(w.ctx, w.ContainerID, string(w.StopSignal)) } -func (w *watcher) containerUnpause() error { +func (w *Watcher) containerUnpause() error { return w.client.ContainerUnpause(w.ctx, w.ContainerID) } -func (w *watcher) containerStart() error { +func (w *Watcher) containerStart() error { return w.client.ContainerStart(w.ctx, w.ContainerID, container.StartOptions{}) } -func (w *watcher) containerStatus() (string, E.NestedError) { +func (w *Watcher) containerStatus() (string, E.NestedError) { json, err := w.client.ContainerInspect(w.ctx, w.ContainerID) if err != nil { return "", E.FailWith("inspect container", err) @@ -168,7 +169,7 @@ func (w *watcher) containerStatus() (string, E.NestedError) { return json.State.Status, nil } -func (w *watcher) wakeIfStopped() E.NestedError { +func (w *Watcher) wakeIfStopped() E.NestedError { if w.ready.Load() || w.ContainerRunning { return nil } @@ -191,7 +192,7 @@ func (w *watcher) wakeIfStopped() E.NestedError { } } -func (w *watcher) getStopCallback() StopCallback { +func (w *Watcher) getStopCallback() StopCallback { var cb func() error switch w.StopMethod { case PT.StopMethodPause: @@ -215,11 +216,11 @@ func (w *watcher) getStopCallback() StopCallback { } } -func (w *watcher) resetIdleTimer() { +func (w *Watcher) resetIdleTimer() { w.ticker.Reset(w.IdleTimeout) } -func (w *watcher) watchUntilCancel() { +func (w *Watcher) watchUntilCancel() { defer close(w.wakeCh) w.ctx, w.cancel = context.WithCancel(mainLoopCtx) diff --git a/internal/docker/inspect.go b/internal/docker/inspect.go index 7ee8e5f..fcafe77 100644 --- a/internal/docker/inspect.go +++ b/internal/docker/inspect.go @@ -7,13 +7,13 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -func (c Client) Inspect(containerID string) (Container, E.NestedError) { +func (c Client) Inspect(containerID string) (*Container, E.NestedError) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() json, err := c.ContainerInspect(ctx, containerID) if err != nil { - return Container{}, E.From(err) + return nil, E.From(err) } - return FromJson(json, c.key), nil + return FromJSON(json, c.key), nil } diff --git a/internal/docker/label.go b/internal/docker/label.go index bea124a..547186e 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -47,7 +47,7 @@ func ApplyLabel[T any](obj *T, l *Label) E.NestedError { case *Label: var field reflect.Value objType := reflect.TypeFor[T]() - for i := 0; i < reflect.TypeFor[T]().NumField(); i++ { + for i := range reflect.TypeFor[T]().NumField() { if objType.Field(i).Tag.Get("yaml") == l.Attribute { field = reflect.ValueOf(obj).Elem().Field(i) break diff --git a/internal/docker/label_test.go b/internal/docker/label_test.go index c2588be..3f178ea 100644 --- a/internal/docker/label_test.go +++ b/internal/docker/label_test.go @@ -8,14 +8,18 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) +const ( + mName = "middleware1" + mAttr = "prop1" + v = "value1" +) + func makeLabel(ns, name, attr string) string { return fmt.Sprintf("%s.%s.%s", ns, name, attr) } func TestNestedLabel(t *testing.T) { - mName := "middleware1" mAttr := "prop1" - v := "value1" pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) ExpectNoError(t, err.Error()) sGot := ExpectType[*Label](t, pl.Value) @@ -28,9 +32,6 @@ func TestApplyNestedLabel(t *testing.T) { entry := new(struct { Middlewares NestedLabelMap `yaml:"middlewares"` }) - mName := "middleware1" - mAttr := "prop1" - v := "value1" pl, err := ParseLabel(makeLabel(NSProxy, "foo", makeLabel("middlewares", mName, mAttr)), v) ExpectNoError(t, err.Error()) err = ApplyLabel(entry, pl) @@ -42,10 +43,6 @@ func TestApplyNestedLabel(t *testing.T) { } func TestApplyNestedLabelExisting(t *testing.T) { - mName := "middleware1" - mAttr := "prop1" - v := "value1" - checkAttr := "prop2" checkV := "value2" entry := new(struct { @@ -71,9 +68,6 @@ func TestApplyNestedLabelExisting(t *testing.T) { } func TestApplyNestedLabelNoAttr(t *testing.T) { - mName := "middleware1" - v := "value1" - entry := new(struct { Middlewares NestedLabelMap `yaml:"middlewares"` }) diff --git a/internal/docker/logger.go b/internal/docker/logger.go new file mode 100644 index 0000000..b86d1be --- /dev/null +++ b/internal/docker/logger.go @@ -0,0 +1,5 @@ +package docker + +import "github.com/sirupsen/logrus" + +var logger = logrus.WithField("module", "docker") diff --git a/internal/docker/proxy_properties.go b/internal/docker/proxy_properties.go deleted file mode 100644 index 3dc2055..0000000 --- a/internal/docker/proxy_properties.go +++ /dev/null @@ -1,25 +0,0 @@ -package docker - -import "github.com/docker/docker/api/types" - -type PortMapping = map[string]types.Port -type ProxyProperties struct { - DockerHost string `yaml:"-" json:"docker_host"` - ContainerName string `yaml:"-" json:"container_name"` - ContainerID string `yaml:"-" json:"container_id"` - ImageName string `yaml:"-" json:"image_name"` - PublicPortMapping PortMapping `yaml:"-" json:"public_port_mapping"` // non-zero publicPort:types.Port - PrivatePortMapping PortMapping `yaml:"-" json:"private_port_mapping"` // privatePort:types.Port - NetworkMode string `yaml:"-" json:"network_mode"` - - Aliases []string `yaml:"-" json:"aliases"` - IsExcluded bool `yaml:"-" json:"is_excluded"` - IsExplicit bool `yaml:"-" json:"is_explicit"` - IsDatabase bool `yaml:"-" json:"is_database"` - IdleTimeout string `yaml:"-" json:"idle_timeout"` - WakeTimeout string `yaml:"-" json:"wake_timeout"` - StopMethod string `yaml:"-" json:"stop_method"` - StopTimeout string `yaml:"-" json:"stop_timeout"` // stop_method = "stop" only - StopSignal string `yaml:"-" json:"stop_signal"` // stop_method = "stop" | "kill" only - Running bool `yaml:"-" json:"running"` -} diff --git a/internal/error/builder.go b/internal/error/builder.go index 4528de7..1fc8a63 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -21,7 +21,7 @@ func NewBuilder(format string, args ...any) Builder { } // adding nil / nil is no-op, -// you may safely pass expressions returning error to it +// you may safely pass expressions returning error to it. func (b Builder) Add(err NestedError) Builder { if err != nil { b.Lock() @@ -39,6 +39,13 @@ func (b Builder) Addf(format string, args ...any) Builder { return b.Add(errorf(format, args...)) } +func (b Builder) AddRangeE(errs ...error) Builder { + for _, err := range errs { + b.AddE(err) + } + return b +} + // Build builds a NestedError based on the errors collected in the Builder. // // If there are no errors in the Builder, it returns a Nil() NestedError. @@ -56,12 +63,13 @@ func (b Builder) Build() NestedError { } func (b Builder) To(ptr *NestedError) { - if ptr == nil { + switch { + case ptr == nil: return - } else if *ptr == nil { + case *ptr == nil: *ptr = b.Build() - } else { - (*ptr).With(b.Build()) + default: + (*ptr).extras = append((*ptr).extras, *b.Build()) } } diff --git a/internal/error/builder_test.go b/internal/error/builder_test.go index de7ce06..9d9011b 100644 --- a/internal/error/builder_test.go +++ b/internal/error/builder_test.go @@ -33,15 +33,13 @@ func TestBuilderNested(t *testing.T) { eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) got := eb.Build().String() - expected1 := - (`error occurred: + expected1 := (`error occurred: - Action 1 failed: - invalid Inner: 1 - invalid Inner: 2 - Action 2 failed: - invalid Inner: 3`) - expected2 := - (`error occurred: + expected2 := (`error occurred: - Action 1 failed: - invalid Inner: "1" - invalid Inner: "2" diff --git a/internal/error/error.go b/internal/error/error.go index 33627cb..56ae7cc 100644 --- a/internal/error/error.go +++ b/internal/error/error.go @@ -8,16 +8,16 @@ import ( ) type ( - NestedError = *nestedError - nestedError struct { + NestedError = *NestedErrorImpl + NestedErrorImpl struct { subject string err error - extras []nestedError + extras []NestedErrorImpl } - jsonNestedError struct { - Subject string - Err string - Extras []jsonNestedError + JSONNestedError struct { + Subject string `json:"subject"` + Err string `json:"error"` + Extras []JSONNestedError `json:"extras,omitempty"` } ) @@ -25,18 +25,18 @@ func From(err error) NestedError { if IsNil(err) { return nil } - return &nestedError{err: err} + return &NestedErrorImpl{err: err} } func FromJSON(data []byte) (NestedError, bool) { - var j jsonNestedError + var j JSONNestedError if err := json.Unmarshal(data, &j); err != nil { return nil, false } if j.Err == "" { return nil, false } - extras := make([]nestedError, len(j.Extras)) + extras := make([]NestedErrorImpl, len(j.Extras)) for i, e := range j.Extras { extra, ok := fromJSONObject(e) if !ok { @@ -44,7 +44,7 @@ func FromJSON(data []byte) (NestedError, bool) { } extras[i] = *extra } - return &nestedError{ + return &NestedErrorImpl{ subject: j.Subject, err: errors.New(j.Err), extras: extras, @@ -58,26 +58,26 @@ func Check[T any](obj T, err error) (T, NestedError) { } func Join(message string, err ...NestedError) NestedError { - extras := make([]nestedError, len(err)) + extras := make([]NestedErrorImpl, len(err)) nErr := 0 for i, e := range err { if e == nil { continue } extras[i] = *e - nErr += 1 + nErr++ } if nErr == 0 { return nil } - return &nestedError{ + return &NestedErrorImpl{ err: errors.New(message), extras: extras, } } func JoinE(message string, err ...error) NestedError { - b := NewBuilder(message) + b := NewBuilder("%s", message) for _, e := range err { b.AddE(e) } @@ -151,7 +151,7 @@ func (ne NestedError) Extraf(format string, args ...any) NestedError { return ne.With(errorf(format, args...)) } -func (ne NestedError) Subject(s any) NestedError { +func (ne NestedError) Subject(s any, sep ...string) NestedError { if ne == nil { return ne } @@ -164,11 +164,12 @@ func (ne NestedError) Subject(s any) NestedError { default: subject = fmt.Sprint(s) } - if ne.subject == "" { + switch { + case ne.subject == "": ne.subject = subject - } else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') { - ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject) - } else { + case len(sep) > 0: + ne.subject = fmt.Sprintf("%s%s%s", subject, sep[0], ne.subject) + default: ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject) } return ne @@ -178,21 +179,15 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError { if ne == nil { return ne } - if strings.Contains(format, "%q") { - panic("Subjectf format should not contain %q") - } - if strings.Contains(format, "%w") { - panic("Subjectf format should not contain %w") - } return ne.Subject(fmt.Sprintf(format, args...)) } -func (ne NestedError) JSONObject() jsonNestedError { - extras := make([]jsonNestedError, len(ne.extras)) +func (ne NestedError) JSONObject() JSONNestedError { + extras := make([]JSONNestedError, len(ne.extras)) for i, e := range ne.extras { extras[i] = e.JSONObject() } - return jsonNestedError{ + return JSONNestedError{ Subject: ne.subject, Err: ne.err.Error(), Extras: extras, @@ -200,7 +195,10 @@ func (ne NestedError) JSONObject() jsonNestedError { } func (ne NestedError) JSON() []byte { - b, _ := json.MarshalIndent(ne.JSONObject(), "", " ") + b, err := json.MarshalIndent(ne.JSONObject(), "", " ") + if err != nil { + panic(err) + } return b } @@ -216,7 +214,7 @@ func errorf(format string, args ...any) NestedError { return From(fmt.Errorf(format, args...)) } -func fromJSONObject(obj jsonNestedError) (NestedError, bool) { +func fromJSONObject(obj JSONNestedError) (NestedError, bool) { data, err := json.Marshal(obj) if err != nil { return nil, false @@ -240,7 +238,7 @@ func (ne NestedError) appendMsg(msg string) NestedError { } func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { - for i := 0; i < level; i++ { + for range level { sb.WriteString(" ") } sb.WriteString(prefix) @@ -267,7 +265,7 @@ func (ne NestedError) buildError(level int, prefix string) error { var res error var sb strings.Builder - for i := 0; i < level; i++ { + for range level { sb.WriteString(" ") } sb.WriteString(prefix) diff --git a/internal/error/error_test.go b/internal/error/error_test.go index 4c08a3f..844f4bf 100644 --- a/internal/error/error_test.go +++ b/internal/error/error_test.go @@ -1,10 +1,9 @@ -package error_test +package error import ( "errors" "testing" - . "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -88,8 +87,7 @@ func TestErrorNested(t *testing.T) { With("baz"). With(inner). With(inner.With(inner2.With(inner3))) - want := - `foo failed: + want := `foo failed: - bar - baz - inner failed: diff --git a/internal/list-icons.go b/internal/list-icons.go index f4b9c84..1764901 100644 --- a/internal/list-icons.go +++ b/internal/list-icons.go @@ -4,12 +4,11 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "os" "time" - "log" - "github.com/yusing/go-proxy/internal/utils" ) @@ -21,8 +20,10 @@ type GitHubContents struct { //! keep this, may reuse in future Size int `json:"size"` } -const iconsCachePath = "/tmp/icons_cache.json" -const updateInterval = 1 * time.Hour +const ( + iconsCachePath = "/tmp/icons_cache.json" + updateInterval = 1 * time.Hour +) func ListAvailableIcons() ([]string, error) { owner := "walkxcode" @@ -30,13 +31,14 @@ func ListAvailableIcons() ([]string, error) { ref := "main" var lastUpdate time.Time - var icons = make([]string, 0) + + icons := make([]string, 0) info, err := os.Stat(iconsCachePath) if err == nil { lastUpdate = info.ModTime().Local() } if time.Since(lastUpdate) < updateInterval { - err := utils.LoadJson(iconsCachePath, &icons) + err := utils.LoadJSON(iconsCachePath, &icons) if err == nil { return icons, nil } @@ -51,7 +53,7 @@ func ListAvailableIcons() ([]string, error) { icons = append(icons, content.Path) } } - err = utils.SaveJson(iconsCachePath, &icons, 0o644).Error() + err = utils.SaveJSON(iconsCachePath, &icons, 0o644).Error() if err != nil { log.Print("error saving cache", err) } @@ -59,7 +61,7 @@ func ListAvailableIcons() ([]string, error) { } func getRepoContents(client *http.Client, owner string, repo string, ref string, path string) ([]GitHubContents, error) { - req, err := http.NewRequest("GET", fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/%s?ref=%s", owner, repo, path, ref), nil) if err != nil { return nil, err } diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 466083f..c298436 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -10,8 +10,8 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -// TODO: stats of each server -// TODO: support weighted mode +// TODO: stats of each server. +// TODO: support weighted mode. type ( impl interface { ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index f875c76..f06b686 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -2,14 +2,14 @@ package middleware import ( "bytes" - "fmt" "io" "net/http" "path/filepath" + "strconv" "strings" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/error_page" + "github.com/yusing/go-proxy/internal/api/v1/errorpage" gphttp "github.com/yusing/go-proxy/internal/net/http" ) @@ -23,14 +23,15 @@ var CustomErrorPage = &Middleware{ // only handles non-success status code and html/plain content type contentType := gphttp.GetContentType(resp.Header) if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { - errorPage, ok := error_page.GetErrorPageByStatus(resp.StatusCode) + errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) if ok { errPageLogger.Debugf("error page for status %d loaded", resp.StatusCode) + /* trunk-ignore(golangci-lint/errcheck) */ io.Copy(io.Discard, resp.Body) // drain the original body resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.ContentLength = int64(len(errorPage)) - resp.Header.Set("Content-Length", fmt.Sprint(len(errorPage))) + resp.Header.Set("Content-Length", strconv.Itoa(len(errorPage))) resp.Header.Set("Content-Type", "text/html; charset=utf-8") } else { errPageLogger.Errorf("unable to load error page for status %d", resp.StatusCode) @@ -48,25 +49,27 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) bool { } if strings.HasPrefix(path, gphttp.StaticFilePathPrefix) { filename := path[len(gphttp.StaticFilePathPrefix):] - file, ok := error_page.GetStaticFile(filename) + file, ok := errorpage.GetStaticFile(filename) if !ok { errPageLogger.Errorf("unable to load resource %s", filename) return false - } else { - ext := filepath.Ext(filename) - switch ext { - case ".html": - w.Header().Set("Content-Type", "text/html; charset=utf-8") - case ".js": - w.Header().Set("Content-Type", "application/javascript; charset=utf-8") - case ".css": - w.Header().Set("Content-Type", "text/css; charset=utf-8") - default: - errPageLogger.Errorf("unexpected file type %q for %s", ext, filename) - } - w.Write(file) - return true } + ext := filepath.Ext(filename) + switch ext { + case ".html": + w.Header().Set("Content-Type", "text/html; charset=utf-8") + case ".js": + w.Header().Set("Content-Type", "application/javascript; charset=utf-8") + case ".css": + w.Header().Set("Content-Type", "text/css; charset=utf-8") + default: + errPageLogger.Errorf("unexpected file type %q for %s", ext, filename) + } + if _, err := w.Write(file); err != nil { + errPageLogger.WithError(err).Errorf("unable to write resource %s", filename) + http.Error(w, "Error page failure", http.StatusInternalServerError) + } + return true } return false } diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 8b23ca8..5defadf 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -30,7 +30,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, } middlewares = make(map[string]*Middleware) for name, defs := range rawMap { - chainErr := E.NewBuilder(name) + chainErr := E.NewBuilder("%s", name) chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"] == "" { @@ -64,7 +64,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, return } -// TODO: check conflict or duplicates +// TODO: check conflict or duplicates. func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { m := &Middleware{name: name, children: chain} diff --git a/internal/net/http/middleware/oauth2.go b/internal/net/http/middleware/oauth2.go index 8b53804..6b352fa 100644 --- a/internal/net/http/middleware/oauth2.go +++ b/internal/net/http/middleware/oauth2.go @@ -92,18 +92,18 @@ func userIsAuthenticated(r *http.Request) bool { return true } -func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) { +func exchangeCodeForToken(code string, opts *oAuth2Opts, requestURI string) (string, error) { // Prepare the request body data := url.Values{ "client_id": {opts.ClientID}, "client_secret": {opts.ClientSecret}, "code": {code}, "grant_type": {"authorization_code"}, - "redirect_uri": {requestUri}, + "redirect_uri": {requestURI}, } resp, err := http.PostForm(opts.TokenURL, data) if err != nil { - return "", fmt.Errorf("failed to request token: %v", err) + return "", fmt.Errorf("failed to request token: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -114,7 +114,7 @@ func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (str AccessToken string `json:"access_token"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token response: %v", err) + return "", fmt.Errorf("failed to decode token response: %w", err) } return tokenResp.AccessToken, nil } diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index 47707c8..ba80b81 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -74,7 +74,7 @@ type testArgs struct { func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) { var body io.Reader - var rr = new(requestRecorder) + var rr requestRecorder var proxyURL *url.URL var requestTarget string var err error @@ -87,11 +87,14 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N body = bytes.NewReader(args.body) } - if args.scheme == "" || args.scheme == "http" { + switch args.scheme { + case "": + fallthrough + case "http": requestTarget = "http://" + testHost - } else if args.scheme == "https" { + case "https": requestTarget = "https://" + testHost - } else { + default: panic("typo?") } @@ -111,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N } else { proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect } - rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr) + rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr) mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 4371990..ca242fa 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -24,10 +24,9 @@ import ( "sync" "github.com/sirupsen/logrus" - "golang.org/x/net/http/httpguts" - "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" + "golang.org/x/net/http/httpguts" ) // A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. @@ -222,6 +221,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport ctx := req.Context() + /* trunk-ignore(golangci-lint/revive) */ if ctx.Done() != nil { // CloseNotifier predates context.Context, and has been // entirely superseded by it. If the request contains @@ -460,7 +460,7 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { - p.errorHandler(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"), true) + p.errorHandler(rw, req, errors.New("internal error: 101 switching protocols response with non-writable body"), true) return } @@ -494,21 +494,24 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R res.Header = rw.Header() res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above if err := res.Write(brw); err != nil { + /* trunk-ignore(golangci-lint/errorlint) */ p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true) return } if err := brw.Flush(); err != nil { + /* trunk-ignore(golangci-lint/errorlint) */ p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true) return } bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn) + /* trunk-ignore(golangci-lint/errcheck) */ bdp.Start() } func IsPrint(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] < ' ' || s[i] > '~' { + for _, r := range s { + if r < ' ' || r > '~' { return false } } diff --git a/internal/proxy/fields/host.go b/internal/proxy/fields/host.go index ca4a0f1..446c051 100644 --- a/internal/proxy/fields/host.go +++ b/internal/proxy/fields/host.go @@ -4,8 +4,10 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -type Host string -type Subdomain = Alias +type ( + Host string + Subdomain = Alias +) func ValidateHost[String ~string](s String) (Host, E.NestedError) { return Host(s), nil diff --git a/internal/proxy/fields/path_pattern.go b/internal/proxy/fields/path_pattern.go index 071b677..5f9b839 100644 --- a/internal/proxy/fields/path_pattern.go +++ b/internal/proxy/fields/path_pattern.go @@ -6,8 +6,12 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -type PathPattern string -type PathPatterns = []PathPattern +type ( + PathPattern string + PathPatterns = []PathPattern +) + +var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) func NewPathPattern(s string) (PathPattern, E.NestedError) { if len(s) == 0 { @@ -25,13 +29,11 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) { } pp := make(PathPatterns, len(s)) for i, v := range s { - if pattern, err := NewPathPattern(v); err.HasError() { + pattern, err := NewPathPattern(v) + if err != nil { return nil, err - } else { - pp[i] = pattern } + pp[i] = pattern } return pp, nil } - -var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) diff --git a/internal/proxy/provider/docker.go b/internal/proxy/provider/docker.go index 35aa909..601dc44 100755 --- a/internal/proxy/provider/docker.go +++ b/internal/proxy/provider/docker.go @@ -1,15 +1,15 @@ package provider import ( - "fmt" "regexp" "strconv" "strings" + "github.com/docker/docker/client" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" - R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/types" W "github.com/yusing/go-proxy/internal/watcher" @@ -17,23 +17,24 @@ import ( ) type DockerProvider struct { - name, dockerHost, hostname string - ExplicitOnly bool + name, dockerHost string + ExplicitOnly bool } -var AliasRefRegex = regexp.MustCompile(`#\d+`) -var AliasRefRegexOld = regexp.MustCompile(`\$\d+`) +var ( + AliasRefRegex = regexp.MustCompile(`#\d+`) + AliasRefRegexOld = regexp.MustCompile(`\$\d+`) +) func DockerProviderImpl(name, dockerHost string, explicitOnly bool) (ProviderImpl, E.NestedError) { - hostname, err := D.ParseDockerHostname(dockerHost) - if err.HasError() { - return nil, err + if dockerHost == common.DockerHostFromEnv { + dockerHost = common.GetEnv("DOCKER_HOST", client.DefaultDockerHost) } - return &DockerProvider{name, dockerHost, hostname, explicitOnly}, nil + return &DockerProvider{name, dockerHost, explicitOnly}, nil } func (p *DockerProvider) String() string { - return fmt.Sprintf("docker: %s", p.name) + return "docker: " + p.name } func (p *DockerProvider) NewWatcher() W.Watcher { @@ -49,7 +50,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { return routes, E.FailWith("connect to docker", err) } - errors := E.NewBuilder("errors when parse docker labels") + errors := E.NewBuilder("errors in docker labels") for _, c := range info.Containers { container := D.FromDocker(&c, p.dockerHost) @@ -80,7 +81,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { return routes, errors.Build() } -func (p *DockerProvider) shouldIgnore(container D.Container) bool { +func (p *DockerProvider) shouldIgnore(container *D.Container) bool { return container.IsExcluded || !container.IsExplicit && p.ExplicitOnly || !container.IsExplicit && container.IsDatabase || @@ -172,8 +173,8 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul } // Returns a list of proxy entries for a container. -// Always non-nil -func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (entries types.RawEntries, _ E.NestedError) { +// Always non-nil. +func (p *DockerProvider) entriesFromContainerLabels(container *D.Container) (entries types.RawEntries, _ E.NestedError) { entries = types.NewProxyEntries() if p.shouldIgnore(container) { @@ -183,9 +184,8 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (entr // init entries map for all aliases for _, a := range container.Aliases { entries.Store(a, &types.RawEntry{ - Alias: a, - Host: p.hostname, - ProxyProperties: container.ProxyProperties, + Alias: a, + Container: container, }) } @@ -202,11 +202,11 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (entr return entries, errors.Build().Subject(container.ContainerName) } -func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntries, key, val string) (res E.NestedError) { +func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEntries, key, val string) (res E.NestedError) { b := E.NewBuilder("errors in label %s", key) defer b.To(&res) - refErr := E.NewBuilder("errors parsing alias references") + refErr := E.NewBuilder("errors in alias references") replaceIndexRef := func(ref string) string { index, err := strconv.Atoi(ref[1:]) if err != nil { @@ -231,7 +231,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntr // apply label for all aliases entries.RangeAll(func(a string, e *types.RawEntry) { if err = D.ApplyLabel(e, lbl); err.HasError() { - b.Add(err.Subjectf("alias %s", lbl.Target)) + b.Add(err) } }) } else { @@ -250,7 +250,7 @@ func (p *DockerProvider) applyLabel(container D.Container, entries types.RawEntr return } if err = D.ApplyLabel(config, lbl); err.HasError() { - b.Add(err.Subjectf("alias %s", lbl.Target)) + b.Add(err) } } return diff --git a/internal/proxy/provider/docker_test.go b/internal/proxy/provider/docker_test.go index 52b6c9e..6618e2f 100644 --- a/internal/proxy/provider/docker_test.go +++ b/internal/proxy/provider/docker_test.go @@ -10,7 +10,6 @@ import ( E "github.com/yusing/go-proxy/internal/error" P "github.com/yusing/go-proxy/internal/proxy" T "github.com/yusing/go-proxy/internal/proxy/fields" - . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -60,7 +59,8 @@ func TestApplyLabelFieldValidity(t *testing.T) { }, Ports: []types.Port{ {Type: "tcp", PrivatePort: 4567, PublicPort: 8888}, - }}, "")) + }, + }, "")) ExpectNoError(t, err.Error()) a, ok := entries.Load("a") @@ -116,8 +116,8 @@ func TestApplyLabel(t *testing.T) { Ports: []types.Port{ {Type: "tcp", PrivatePort: 3333, PublicPort: 1111}, {Type: "tcp", PrivatePort: 4444, PublicPort: 1234}, - }}, "", - )) + }, + }, "")) a, ok := entries.Load("a") ExpectTrue(t, ok) b, ok := entries.Load("b") @@ -152,7 +152,8 @@ func TestApplyLabelWithRef(t *testing.T) { {Type: "tcp", PrivatePort: 3333, PublicPort: 9999}, {Type: "tcp", PrivatePort: 4444, PublicPort: 5555}, {Type: "tcp", PrivatePort: 1111, PublicPort: 2222}, - }}, "")) + }, + }, "")) a, ok := entries.Load("a") ExpectTrue(t, ok) b, ok := entries.Load("b") @@ -171,13 +172,14 @@ func TestApplyLabelWithRef(t *testing.T) { func TestApplyLabelWithRefIndexError(t *testing.T) { var p DockerProvider - var c = D.FromDocker(&types.Container{ + c := D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ D.LabelAliases: "a,b", "proxy.#1.host": "localhost", "proxy.#4.scheme": "https", - }}, "") + }, + }, "") _, err := p.entriesFromContainerLabels(c) ExpectError(t, E.ErrOutOfRange, err.Error()) ExpectTrue(t, strings.Contains(err.String(), "index out of range")) @@ -187,14 +189,15 @@ func TestApplyLabelWithRefIndexError(t *testing.T) { Labels: map[string]string{ D.LabelAliases: "a,b", "proxy.#0.host": "localhost", - }}, "")) + }, + }, "")) ExpectError(t, E.ErrOutOfRange, err.Error()) ExpectTrue(t, strings.Contains(err.String(), "index out of range")) } func TestStreamDefaultValues(t *testing.T) { var p DockerProvider - var c = D.FromDocker(&types.Container{ + c := D.FromDocker(&types.Container{ Names: dummyNames, Labels: map[string]string{ D.LabelAliases: "a", @@ -202,7 +205,8 @@ func TestStreamDefaultValues(t *testing.T) { }, Ports: []types.Port{ {Type: "udp", PrivatePort: 1234, PublicPort: 5678}, - }}, "", + }, + }, "", ) entries, err := p.entriesFromContainerLabels(c) ExpectNoError(t, err.Error()) @@ -228,7 +232,8 @@ func TestExplicitExclude(t *testing.T) { D.LabelAliases: "a", D.LabelExclude: "true", "proxy.a.no_tls_verify": "true", - }}, "")) + }, + }, "")) ExpectNoError(t, err.Error()) _, ok := entries.Load("a") diff --git a/internal/proxy/provider/provider.go b/internal/proxy/provider/provider.go index 322789b..4585f36 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/proxy/provider/provider.go @@ -99,31 +99,21 @@ func (p *Provider) GetType() ProviderType { return p.t } -// to work with json marshaller +// to work with json marshaller. func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } func (p *Provider) StartAllRoutes() (res E.NestedError) { - errors := E.NewBuilder("errors in routes") + errors := E.NewBuilder("errors starting routes") defer errors.To(&res) // start watcher no matter load success or not go p.watchEvents() - nStarted := 0 - nFailed := 0 - p.routes.RangeAllParallel(func(alias string, r R.Route) { - if err := r.Start(); err.HasError() { - errors.Add(err.Subject(r)) - nFailed++ - } else { - nStarted++ - } + errors.Add(r.Start().Subject(r)) }) - - p.l.Debugf("%d routes started, %d failed", nStarted, nFailed) return } @@ -133,20 +123,12 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) { p.watcherCancel = nil } - errors := E.NewBuilder("errors stopping routes for provider %q", p.name) + errors := E.NewBuilder("errors stopping routes") defer errors.To(&res) - nStopped := 0 - nFailed := 0 p.routes.RangeAllParallel(func(alias string, r R.Route) { - if err := r.Stop(); err.HasError() { - errors.Add(err.Subject(r)) - nFailed++ - } else { - nStopped++ - } + errors.Add(r.Stop().Subject(r)) }) - p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed) return } @@ -165,6 +147,9 @@ func (p *Provider) LoadRoutes() E.NestedError { p.l.Infof("loaded %d routes", p.routes.Size()) return err } + if err == nil { + return nil + } return E.FailWith("loading routes", err) } diff --git a/internal/route/constants.go b/internal/route/constants.go index 46554b5..cffe105 100644 --- a/internal/route/constants.go +++ b/internal/route/constants.go @@ -4,5 +4,7 @@ import ( "time" ) -const udpBufferSize = 8192 -const streamStopListenTimeout = 1 * time.Second +const ( + udpBufferSize = 8192 + streamStopListenTimeout = 1 * time.Second +) diff --git a/internal/route/http.go b/internal/route/http.go index fff7dd3..76719e4 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -1,17 +1,17 @@ package route import ( + "errors" "fmt" - "sync" - "net/http" "strings" + "sync" "github.com/sirupsen/logrus" - "github.com/yusing/go-proxy/internal/api/v1/error_page" + "github.com/yusing/go-proxy/internal/api/v1/errorpage" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" - . "github.com/yusing/go-proxy/internal/net/http" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" P "github.com/yusing/go-proxy/internal/proxy" @@ -26,13 +26,13 @@ type ( server *loadbalancer.Server handler http.Handler - rp *ReverseProxy + rp *gphttp.ReverseProxy } SubdomainKey = PT.Alias ReverseProxyHandler struct { - *ReverseProxy + *gphttp.ReverseProxy } ) @@ -41,7 +41,7 @@ var ( httpRoutes = F.NewMapOf[string, *HTTPRoute]() httpRoutesMu sync.Mutex - globalMux = http.NewServeMux() // TODO: support regex subdomain matching + // globalMux = http.NewServeMux() // TODO: support regex subdomain matching. ) func (rp ReverseProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -60,12 +60,12 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { var trans *http.Transport if entry.NoTLSVerify { - trans = DefaultTransportNoTLS.Clone() + trans = gphttp.DefaultTransportNoTLS.Clone() } else { - trans = DefaultTransport.Clone() + trans = gphttp.DefaultTransport.Clone() } - rp := NewReverseProxy(entry.URL, trans) + rp := gphttp.NewReverseProxy(entry.URL, trans) if len(entry.Middlewares) > 0 { err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares) @@ -122,7 +122,7 @@ func (r *HTTPRoute) Start() E.NestedError { } var lb *loadbalancer.LoadBalancer - linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) + linked, ok := httpRoutes.Load(r.LoadBalance.Link) if ok { lb = linked.LoadBalancer } else { @@ -132,7 +132,7 @@ func (r *HTTPRoute) Start() E.NestedError { LoadBalancer: lb, handler: lb, } - httpRoutes.Store(string(r.LoadBalance.Link), linked) + httpRoutes.Store(r.LoadBalance.Link, linked) } r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler) lb.AddServer(r.server) @@ -152,12 +152,12 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { } if r.server != nil { - linked, ok := httpRoutes.Load(string(r.LoadBalance.Link)) + linked, ok := httpRoutes.Load(r.LoadBalance.Link) if ok { linked.LoadBalancer.RemoveServer(r.server) } if linked.LoadBalancer.IsEmpty() { - httpRoutes.Delete(string(r.LoadBalance.Link)) + httpRoutes.Delete(r.LoadBalance.Link) } r.server = nil } else { @@ -180,11 +180,13 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { logrus.Error(E.Failure("request"). Subjectf("%s %s", r.Method, r.URL.String()). With(err)) - errorPage, ok := error_page.GetErrorPageByStatus(http.StatusNotFound) + errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write(errorPage) + if _, err := w.Write(errorPage); err != nil { + logrus.Errorf("failed to respond error page to %s: %s", r.RemoteAddr, err) + } } else { http.Error(w, err.Error(), http.StatusNotFound) } @@ -198,7 +200,7 @@ func findMuxAnyDomain(host string) (http.Handler, error) { hostSplit := strings.Split(host, ".") n := len(hostSplit) if n <= 2 { - return nil, fmt.Errorf("missing subdomain in url") + return nil, errors.New("missing subdomain in url") } sd := strings.Join(hostSplit[:n-2], ".") if r, ok := httpRoutes.Load(sd); ok { diff --git a/internal/route/stream.go b/internal/route/stream.go index cedd67b..2f7b174 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -30,7 +30,7 @@ type StreamRoute struct { type StreamImpl interface { Setup() error Accept() (any, error) - Handle(any) error + Handle(conn any) error CloseListeners() String() string } diff --git a/internal/route/udp.go b/internal/route/udp.go index cecc16d..e3cde37 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -65,7 +65,6 @@ func (route *UDPRoute) Accept() (any, error) { buffer := make([]byte, udpBufferSize) nRead, srcAddr, err := in.ReadFromUDP(buffer) - if err != nil { return nil, err } diff --git a/internal/server/instance.go b/internal/server/instance.go index ba319f8..768f5b3 100644 --- a/internal/server/instance.go +++ b/internal/server/instance.go @@ -1,25 +1,25 @@ package server -var proxyServer, apiServer *server +var proxyServer, apiServer *Server -func InitProxyServer(opt Options) *server { +func InitProxyServer(opt Options) *Server { if proxyServer == nil { proxyServer = NewServer(opt) } return proxyServer } -func InitAPIServer(opt Options) *server { +func InitAPIServer(opt Options) *Server { if apiServer == nil { apiServer = NewServer(opt) } return apiServer } -func GetProxyServer() *server { +func GetProxyServer() *Server { return proxyServer } -func GetAPIServer() *server { +func GetAPIServer() *Server { return apiServer } diff --git a/internal/server/server.go b/internal/server/server.go index d29830b..fb94634 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "crypto/tls" + "errors" "log" "net/http" "time" @@ -11,7 +12,7 @@ import ( "golang.org/x/net/context" ) -type server struct { +type Server struct { Name string CertProvider *autocert.Provider http *http.Server @@ -38,7 +39,7 @@ func (l LogrusWrapper) Write(b []byte) (int, error) { return l.Logger.WriterLevel(logrus.ErrorLevel).Write(b) } -func NewServer(opt Options) (s *server) { +func NewServer(opt Options) (s *Server) { var httpSer, httpsSer *http.Server var httpHandler http.Handler @@ -76,7 +77,7 @@ func NewServer(opt Options) (s *server) { }, } } - return &server{ + return &Server{ Name: opt.Name, CertProvider: opt.CertProvider, http: httpSer, @@ -88,8 +89,8 @@ func NewServer(opt Options) (s *server) { // // If both are not set, this does nothing. // -// Start() is non-blocking -func (s *server) Start() { +// Start() is non-blocking. +func (s *Server) Start() { if s.http == nil && s.https == nil { return } @@ -112,7 +113,7 @@ func (s *server) Start() { } } -func (s *server) Stop() { +func (s *Server) Stop() { if s.http == nil && s.https == nil { return } @@ -133,13 +134,13 @@ func (s *server) Stop() { } } -func (s *server) Uptime() time.Duration { +func (s *Server) Uptime() time.Duration { return time.Since(s.startTime) } -func (s *server) handleErr(scheme string, err error) { - switch err { - case nil, http.ErrServerClosed: +func (s *Server) handleErr(scheme string, err error) { + switch { + case err == nil, errors.Is(err, http.ErrServerClosed): return default: logrus.Fatalf("failed to start %s %s server: %s", scheme, s.Name, err) diff --git a/internal/setup.go b/internal/setup.go index e71f54a..f067fec 100644 --- a/internal/setup.go +++ b/internal/setup.go @@ -1,7 +1,6 @@ package internal import ( - "fmt" "io" "log" "net/http" @@ -9,16 +8,18 @@ import ( "os" "path" - . "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/common" ) -var branch = GetEnv("GOPROXY_BRANCH", "v0.5") -var baseUrl = fmt.Sprintf("https://github.com/yusing/go-proxy/raw/%s", branch) -var requiredConfigs = []Config{ - {ConfigBasePath, true, false, ""}, - {ComposeFileName, false, true, ComposeExampleFileName}, - {path.Join(ConfigBasePath, ConfigFileName), false, true, ConfigExampleFileName}, -} +var ( + branch = common.GetEnv("GOPROXY_BRANCH", "v0.6") + baseURL = "https://github.com/yusing/go-proxy/raw/" + branch + requiredConfigs = []Config{ + {common.ConfigBasePath, true, false, ""}, + {common.ComposeFileName, false, true, common.ComposeExampleFileName}, + {path.Join(common.ConfigBasePath, common.ConfigFileName), false, true, common.ConfigExampleFileName}, + } +) type Config struct { Pathname string @@ -31,7 +32,9 @@ func Setup() { log.Println("setting up go-proxy") log.Println("branch:", branch) - os.Chdir("/setup") + if err := os.Chdir("/setup"); err != nil { + log.Fatalf("failed: %s\n", err) + } for _, config := range requiredConfigs { config.setup() @@ -83,6 +86,7 @@ func touch(pathname string) { log.Fatalf("failed: %s\n", err) } } + func fetch(remoteFilename string, outFileName string) { if hasFileOrDir(outFileName) { if remoteFilename == outFileName { @@ -94,7 +98,7 @@ func fetch(remoteFilename string, outFileName string) { } log.Printf("downloading %q\n", remoteFilename) - url, err := url.JoinPath(baseUrl, remoteFilename) + url, err := url.JoinPath(baseURL, remoteFilename) if err != nil { log.Fatalf("unexpected error: %s\n", err) } @@ -104,17 +108,19 @@ func fetch(remoteFilename string, outFileName string) { log.Fatalf("http request failed: %s\n", err) } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) if err != nil { + resp.Body.Close() log.Fatalf("error reading response body: %s\n", err) } err = os.WriteFile(outFileName, body, 0o644) if err != nil { + resp.Body.Close() log.Fatalf("failed to write to file: %s\n", err) } log.Printf("downloaded to %q\n", outFileName) + + resp.Body.Close() } diff --git a/internal/types/autocert_config.go b/internal/types/autocert_config.go index 2d2205e..d5687d7 100644 --- a/internal/types/autocert_config.go +++ b/internal/types/autocert_config.go @@ -2,12 +2,12 @@ package types type ( AutoCertConfig struct { - Email string `json:"email"` - Domains []string `yaml:",flow" json:"domains"` - CertPath string `yaml:"cert_path" json:"cert_path"` - KeyPath string `yaml:"key_path" json:"key_path"` - Provider string `json:"provider"` - Options AutocertProviderOpt `yaml:",flow" json:"options"` + Email string `json:"email,omitempty" yaml:"email"` + Domains []string `json:"domains,omitempty" yaml:",flow"` + CertPath string `json:"cert_path,omitempty" yaml:"cert_path"` + KeyPath string `json:"key_path,omitempty" yaml:"key_path"` + Provider string `json:"provider,omitempty" yaml:"provider"` + Options AutocertProviderOpt `json:"options,omitempty" yaml:",flow"` } AutocertProviderOpt map[string]any ) diff --git a/internal/types/config.go b/internal/types/config.go index d9ef600..bbbc21f 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -1,12 +1,12 @@ package types type Config struct { - Providers ProxyProviders `yaml:",flow" json:"providers"` - AutoCert AutoCertConfig `yaml:",flow" json:"autocert"` - ExplicitOnly bool `yaml:"explicit_only" json:"explicit_only"` - MatchDomains []string `yaml:"match_domains" json:"match_domains"` - TimeoutShutdown int `yaml:"timeout_shutdown" json:"timeout_shutdown"` - RedirectToHTTPS bool `yaml:"redirect_to_https" json:"redirect_to_https"` + Providers ProxyProviders `json:"providers" yaml:",flow"` + AutoCert AutoCertConfig `json:"autocert" yaml:",flow"` + ExplicitOnly bool `json:"explicit_only" yaml:"explicit_only"` + MatchDomains []string `json:"match_domains" yaml:"match_domains"` + TimeoutShutdown int `json:"timeout_shutdown" yaml:"timeout_shutdown"` + RedirectToHTTPS bool `json:"redirect_to_https" yaml:"redirect_to_https"` } func DefaultConfig() *Config { diff --git a/internal/types/proxy_providers.go b/internal/types/proxy_providers.go index eac0132..7ba4efa 100644 --- a/internal/types/proxy_providers.go +++ b/internal/types/proxy_providers.go @@ -1,6 +1,6 @@ package types type ProxyProviders struct { - Files []string `yaml:"include" json:"include"` // docker, file - Docker map[string]string `yaml:"docker" json:"docker"` + Files []string `json:"include" yaml:"include"` // docker, file + Docker map[string]string `json:"docker" yaml:"docker"` } diff --git a/internal/types/raw_entry.go b/internal/types/raw_entry.go index 1f4cf6b..90df23f 100644 --- a/internal/types/raw_entry.go +++ b/internal/types/raw_entry.go @@ -1,33 +1,35 @@ package types import ( - "fmt" "strconv" "strings" - . "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" H "github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" + U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) type ( RawEntry struct { + _ U.NoCopy + // raw entry object before validation // loaded from docker labels or yaml file - Alias string `yaml:"-" json:"-"` - Scheme string `yaml:"scheme" json:"scheme"` - Host string `yaml:"host" json:"host"` - Port string `yaml:"port" json:"port"` - NoTLSVerify bool `yaml:"no_tls_verify" json:"no_tls_verify,omitempty"` // https proxy only - PathPatterns []string `yaml:"path_patterns" json:"path_patterns,omitempty"` // http(s) proxy only - LoadBalance loadbalancer.Config `yaml:"load_balance" json:"load_balance"` - Middlewares D.NestedLabelMap `yaml:"middlewares" json:"middlewares,omitempty"` - Homepage *H.HomePageItem `yaml:"homepage" json:"homepage,omitempty"` + Alias string `json:"-" yaml:"-"` + Scheme string `json:"scheme" yaml:"scheme"` + Host string `json:"host" yaml:"host"` + Port string `json:"port" yaml:"port"` + NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only + PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only + LoadBalance loadbalancer.Config `json:"load_balance" yaml:"load_balance"` + Middlewares D.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` + Homepage *H.HomePageItem `json:"homepage,omitempty" yaml:"homepage"` /* Docker only */ - *D.ProxyProperties `yaml:"-" json:"proxy_properties"` + *D.Container `json:"container" yaml:"-"` } RawEntries = F.Map[string, *RawEntry] @@ -36,21 +38,32 @@ type ( var NewProxyEntries = F.NewMapOf[string, *RawEntry] func (e *RawEntry) FillMissingFields() { - isDocker := e.ProxyProperties != nil + isDocker := e.Container != nil if !isDocker { - e.ProxyProperties = &D.ProxyProperties{} + e.Container = &D.Container{} + } + + if e.Host == "" { + switch { + case e.PrivateIP != "": + e.Host = e.PrivateIP + case e.PublicIP != "": + e.Host = e.PublicIP + default: + e.Host = "localhost" + } } lp, pp, extra := e.splitPorts() - if port, ok := ServiceNamePortMapTCP[e.ImageName]; ok { + if port, ok := common.ServiceNamePortMapTCP[e.ImageName]; ok { if pp == "" { pp = strconv.Itoa(port) } if e.Scheme == "" { e.Scheme = "tcp" } - } else if port, ok := ImageNamePortMap[e.ImageName]; ok { + } else if port, ok := common.ImageNamePortMap[e.ImageName]; ok { if pp == "" { pp = strconv.Itoa(port) } @@ -61,58 +74,68 @@ func (e *RawEntry) FillMissingFields() { pp = "443" } else if pp == "" { if p, ok := F.FirstValueOf(e.PrivatePortMapping); ok { - pp = fmt.Sprint(p.PrivatePort) + pp = U.PortString(p.PrivatePort) } else if !isDocker { pp = "80" } } - // replace private port with public port (if any) - if isDocker && e.NetworkMode != "host" { + // replace private port with public port if using public IP. + if e.Host == e.PublicIP { if p, ok := e.PrivatePortMapping[pp]; ok { - pp = fmt.Sprint(p.PublicPort) + pp = U.PortString(p.PublicPort) } if _, ok := e.PublicPortMapping[pp]; !ok { // port is not exposed, but specified // try to fallback to first public port if p, ok := F.FirstValueOf(e.PublicPortMapping); ok { - pp = fmt.Sprint(p.PublicPort) + pp = U.PortString(p.PublicPort) + } + } + } + // replace public port with private port if using private IP. + if e.Host == e.PrivateIP { + if p, ok := e.PublicPortMapping[pp]; ok { + pp = U.PortString(p.PrivatePort) + } + if _, ok := e.PrivatePortMapping[pp]; !ok { // port is not exposed, but specified + // try to fallback to first private port + if p, ok := F.FirstValueOf(e.PrivatePortMapping); ok { + pp = U.PortString(p.PrivatePort) } } } if e.Scheme == "" && isDocker { - if p, ok := e.PublicPortMapping[pp]; ok && p.Type == "udp" { + switch { + case e.Host == e.PublicIP && e.PublicPortMapping[pp].Type == "udp": + e.Scheme = "udp" + case e.Host == e.PrivateIP && e.PrivatePortMapping[pp].Type == "udp": e.Scheme = "udp" } } if e.Scheme == "" { - if lp != "" { + switch { + case lp != "": e.Scheme = "tcp" - } else if strings.HasSuffix(pp, "443") { + case strings.HasSuffix(pp, "443"): e.Scheme = "https" - } else if _, ok := WellKnownHTTPPorts[pp]; ok { - e.Scheme = "http" - } else { - // assume its http + default: // assume its http e.Scheme = "http" } } - if e.Host == "" { - e.Host = "localhost" - } if e.IdleTimeout == "" { - e.IdleTimeout = IdleTimeoutDefault + e.IdleTimeout = common.IdleTimeoutDefault } if e.WakeTimeout == "" { - e.WakeTimeout = WakeTimeoutDefault + e.WakeTimeout = common.WakeTimeoutDefault } if e.StopTimeout == "" { - e.StopTimeout = StopTimeoutDefault + e.StopTimeout = common.StopTimeoutDefault } if e.StopMethod == "" { - e.StopMethod = StopMethodDefault + e.StopMethod = common.StopMethodDefault } e.Port = joinPorts(lp, pp, extra) diff --git a/internal/utils/fs.go b/internal/utils/fs.go index 2332220..82ee788 100644 --- a/internal/utils/fs.go +++ b/internal/utils/fs.go @@ -7,7 +7,7 @@ import ( ) // Recursively lists all files in a directory until `maxDepth` is reached -// Returns a slice of file paths relative to `dir` +// Returns a slice of file paths relative to `dir`. func ListFiles(dir string, maxDepth int) ([]string, error) { entries, err := os.ReadDir(dir) if err != nil { diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index b578e5b..2c0dfe7 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -4,9 +4,8 @@ import ( "sync" "github.com/puzpuzpuz/xsync/v3" - "gopkg.in/yaml.v3" - E "github.com/yusing/go-proxy/internal/error" + "gopkg.in/yaml.v3" ) type Map[KT comparable, VT any] struct { @@ -25,6 +24,17 @@ func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) { return } +// MapFind iterates over the map and returns the first value +// that satisfies the given criteria. The iteration is stopped +// once a value is found. If no value satisfies the criteria, +// the function returns the zero value of CT. +// +// The criteria function takes a value of type VT and returns a +// value of type CT and a boolean indicating whether the value +// satisfies the criteria. The boolean value is used to determine +// whether the iteration should be stopped. +// +// The function is safe for concurrent use. func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) { result := make(chan CT, 1) @@ -49,13 +59,15 @@ func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bo } } -// MergeFrom add contents from another `Map`, ignore duplicated keys +// MergeFrom merges the contents of another Map into this one, ignoring duplicated keys. // // Parameters: -// - other: `Map` of values to add from // -// Return: -// - Map: a `Map` of duplicated keys-value pairs +// other: Map of values to add from +// +// Returns: +// +// Map of duplicated keys-value pairs func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] { dups := NewMapOf[KT, VT]() @@ -70,6 +82,15 @@ func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] { return dups } +// RangeAll calls the given function for each key-value pair in the map. +// +// Parameters: +// +// do: function to call for each key-value pair +// +// Returns: +// +// nothing func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { m.Range(func(k KT, v VT) bool { do(k, v) @@ -77,6 +98,16 @@ func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { }) } +// RangeAllParallel calls the given function for each key-value pair in the map, +// in parallel. The map is not safe for modification from within the function. +// +// Parameters: +// +// do: function to call for each key-value pair +// +// Returns: +// +// nothing func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { var wg sync.WaitGroup wg.Add(m.Size()) @@ -91,6 +122,15 @@ func (m Map[KT, VT]) RangeAllParallel(do func(k KT, v VT)) { wg.Wait() } +// RemoveAll removes all key-value pairs from the map where the value matches the given criteria. +// +// Parameters: +// +// criteria: function to determine whether a value should be removed +// +// Returns: +// +// nothing func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) { m.Range(func(k KT, v VT) bool { if criteria(v) { @@ -105,6 +145,17 @@ func (m Map[KT, VT]) Has(k KT) bool { return ok } +// UnmarshalFromYAML unmarshals a yaml byte slice into the map. +// +// It overwrites all existing key-value pairs in the map. +// +// Parameters: +// +// data: yaml byte slice to unmarshal +// +// Returns: +// +// error: if the unmarshaling fails func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { if m.Size() != 0 { return E.FailedWhy("unmarshal from yaml", "map is not empty") diff --git a/internal/utils/io.go b/internal/utils/io.go index 5987b7c..406197b 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -12,7 +12,7 @@ import ( E "github.com/yusing/go-proxy/internal/error" ) -// TODO: move to "utils/io" +// TODO: move to "utils/io". type ( FileReader struct { Path string @@ -108,7 +108,7 @@ func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) } -func LoadJson[T any](path string, pointer *T) E.NestedError { +func LoadJSON[T any](path string, pointer *T) E.NestedError { data, err := E.Check(os.ReadFile(path)) if err.HasError() { return err @@ -116,7 +116,7 @@ func LoadJson[T any](path string, pointer *T) E.NestedError { return E.From(json.Unmarshal(data, pointer)) } -func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError { +func SaveJSON[T any](path string, pointer *T, perm os.FileMode) E.NestedError { data, err := E.Check(json.Marshal(pointer)) if err.HasError() { return err diff --git a/internal/utils/nocopy.go b/internal/utils/nocopy.go new file mode 100644 index 0000000..f344374 --- /dev/null +++ b/internal/utils/nocopy.go @@ -0,0 +1,8 @@ +package utils + +// empty struct that implements Locker interface +// for hinting that no copy should be performed. +type NoCopy struct{} + +func (*NoCopy) Lock() {} +func (*NoCopy) Unlock() {} diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 12d564e..7a8f0b9 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -3,6 +3,7 @@ package utils import ( "bytes" "encoding/json" + "errors" "fmt" "reflect" "strconv" @@ -14,10 +15,12 @@ import ( "gopkg.in/yaml.v3" ) -type SerializedObject = map[string]any -type Converter interface { - ConvertFrom(value any) (any, E.NestedError) -} +type ( + SerializedObject = map[string]any + Converter interface { + ConvertFrom(value any) (any, E.NestedError) + } +) func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { var i any @@ -37,11 +40,16 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { return nil } - errors := E.NewBuilder("yaml validation error") - for _, e := range err.(*jsonschema.ValidationError).Causes { - errors.AddE(e) + var valErr *jsonschema.ValidationError + if !errors.As(err, &valErr) { + return E.UnexpectedError(err) } - return errors.Build() + + b := E.NewBuilder("yaml validation error") + for _, e := range valErr.Causes { + b.AddE(e) + } + return b.Build() } // Serialize converts the given data into a map[string]any representation. @@ -80,7 +88,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { result[key.String()] = value.MapIndex(key).Interface() } case reflect.Struct: - for i := 0; i < value.NumField(); i++ { + for i := range value.NumField() { field := value.Type().Field(i) if !field.IsExported() { continue @@ -91,9 +99,10 @@ func Serialize(data any) (SerializedObject, E.NestedError) { } // If the json tag is not empty, use it as the key - if jsonTag != "" { + switch { + case jsonTag != "": result[jsonTag] = value.Field(i).Interface() - } else if field.Anonymous { + case field.Anonymous: // If the field is an embedded struct, add its fields to the result fieldMap, err := Serialize(value.Field(i).Interface()) if err != nil { @@ -102,7 +111,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { for k, v := range fieldMap { result[k] = v } - } else { + default: result[field.Name] = value.Field(i).Interface() } } @@ -147,7 +156,8 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { // TODO: use E.Builder to collect errors from all fields - if dstV.Kind() == reflect.Struct { + switch dstV.Kind() { + case reflect.Struct: mapping := make(map[string]reflect.Value) for _, field := range reflect.VisibleFields(dstT) { mapping[ToLowerNoSnake(field.Name)] = dstV.FieldByName(field.Name) @@ -162,7 +172,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { return E.Unexpected("field", k).Subjectf("%T", dst) } } - } else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String { + case reflect.Map: if dstV.IsNil() { dstV.Set(reflect.MakeMap(dstT)) } @@ -174,8 +184,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { } dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp) } - return nil - } else { + default: return E.Unsupported("target type", fmt.Sprintf("%T", dst)) } @@ -362,7 +371,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N return true, Convert(reflect.ValueOf(tmp), dst) } -func DeserializeJson(j map[string]string, target any) E.NestedError { +func DeserializeJSON(j map[string]string, target any) E.NestedError { data, err := E.Check(json.Marshal(j)) if err != nil { return err diff --git a/internal/utils/string.go b/internal/utils/string.go index ea2376c..1504cc5 100644 --- a/internal/utils/string.go +++ b/internal/utils/string.go @@ -9,7 +9,7 @@ import ( "golang.org/x/text/language" ) -// TODO: support other languages +// TODO: support other languages. var titleCaser = cases.Title(language.AmericanEnglish) func CommaSeperatedList(s string) []string { @@ -31,3 +31,7 @@ func ExtractPort(fullURL string) (int, error) { } return strconv.Atoi(url.Port()) } + +func PortString(port uint16) string { + return strconv.FormatUint(uint64(port), 10) +} diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index bb184d3..0b58827 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -92,7 +92,6 @@ func ExpectType[T any](t *testing.T, got any) (_ T) { _, ok := got.(T) if !ok { t.Fatalf("expected type %s, got %s", tExpect, reflect.TypeOf(got).Elem()) - t.FailNow() return } return got.(T) diff --git a/internal/watcher/config_file_watcher.go b/internal/watcher/config_file_watcher.go index ddb3ea8..4f7c512 100644 --- a/internal/watcher/config_file_watcher.go +++ b/internal/watcher/config_file_watcher.go @@ -7,10 +7,12 @@ import ( "github.com/yusing/go-proxy/internal/common" ) -var configDirWatcher *dirWatcher -var configDirWatcherMu sync.Mutex +var ( + configDirWatcher *DirWatcher + configDirWatcherMu sync.Mutex +) -// create a new file watcher for file under ConfigBasePath +// create a new file watcher for file under ConfigBasePath. func NewConfigFileWatcher(filename string) Watcher { configDirWatcherMu.Lock() defer configDirWatcherMu.Unlock() diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index acbd7b3..8a9d9e4 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -13,7 +13,7 @@ import ( "github.com/yusing/go-proxy/internal/watcher/events" ) -type dirWatcher struct { +type DirWatcher struct { dir string w *fsnotify.Watcher @@ -26,7 +26,7 @@ type dirWatcher struct { ctx context.Context } -func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher { +func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { @@ -35,7 +35,7 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher { if err = w.Add(dirPath); err != nil { logrus.Panicf("unable to create fs watcher: %s", err) } - helper := &dirWatcher{ + helper := &DirWatcher{ dir: dirPath, w: w, fwMap: F.NewMapOf[string, *fileWatcher](), @@ -47,11 +47,11 @@ func NewDirectoryWatcher(ctx context.Context, dirPath string) *dirWatcher { return helper } -func (h *dirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) { +func (h *DirWatcher) Events(_ context.Context) (<-chan Event, <-chan E.NestedError) { return h.eventCh, h.errCh } -func (h *dirWatcher) Add(relPath string) *fileWatcher { +func (h *DirWatcher) Add(relPath string) Watcher { h.mu.Lock() defer h.mu.Unlock() @@ -85,7 +85,7 @@ func (h *dirWatcher) Add(relPath string) *fileWatcher { return s } -func (h *dirWatcher) start() { +func (h *DirWatcher) start() { defer close(h.eventCh) defer h.w.Close()