diff --git a/.golangci.yml b/.golangci.yml index 078682a..8eb57f6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,15 +2,16 @@ version: "2" linters: default: all disable: - - bodyclose + # - bodyclose - containedctx - - contextcheck + # - contextcheck - cyclop - depguard - - dupl + # - dupl - err113 - exhaustive - exhaustruct + - funcorder - forcetypeassert - gochecknoglobals - gochecknoinits @@ -18,7 +19,6 @@ linters: - goconst - gocyclo - gomoddirectives - - gosec - gosmopolitan - ireturn - lll @@ -27,12 +27,10 @@ linters: - mnd - nakedret - nestif - - nilnil - nlreturn - - noctx - nonamedreturns - paralleltest - - prealloc + - revive - rowserrcheck - sqlclosecheck - tagliatelle diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index bf24f6d..ed25d81 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -21,7 +21,7 @@ lint: - markdownlint - yamllint enabled: - - checkov@3.2.416 + - checkov@3.2.432 - golangci-lint2@2.1.6 - hadolint@2.12.1-beta - actionlint@1.7.7 @@ -32,7 +32,7 @@ lint: - prettier@3.5.3 - shellcheck@0.10.0 - shfmt@3.6.0 - - trufflehog@3.88.29 + - trufflehog@3.88.33 actions: disabled: - trunk-announce diff --git a/README.md b/README.md index d091204..85e5172 100755 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ A lightweight, simple, and performant reverse proxy with WebUI.
EN | 中文
+Have questions? Ask [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)! (Thanks to [@ismesid](https://github.com/arevindh)) + diff --git a/README_CHT.md b/README_CHT.md index f9e0ca2..bd95fc4 100644 --- a/README_CHT.md +++ b/README_CHT.md @@ -16,6 +16,8 @@
EN | 中文
+有疑問? 問 [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)!(鳴謝 [@ismesid](https://github.com/arevindh)) + diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 7c891b8..d2167b6 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,11 +1,14 @@ package main import ( + "os" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/server" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/metrics/systeminfo" httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/task" @@ -14,6 +17,12 @@ import ( ) func main() { + writer := zerolog.ConsoleWriter{ + Out: os.Stderr, + TimeFormat: "01-02 15:04", + } + zerolog.TimeFieldFormat = writer.TimeFormat + log.Logger = zerolog.New(writer).Level(zerolog.InfoLevel).With().Timestamp().Logger() ca := &agent.PEMPair{} err := ca.Load(env.AgentCACert) if err != nil { @@ -34,11 +43,11 @@ func main() { gperr.LogFatal("init SSL error", err) } - logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion()) - logging.Info().Msgf("Agent name: %s", env.AgentName) - logging.Info().Msgf("Agent port: %d", env.AgentPort) + log.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion()) + log.Info().Msgf("Agent name: %s", env.AgentName) + log.Info().Msgf("Agent port: %d", env.AgentPort) - logging.Info().Msg(` + log.Info().Msg(` Tips: 1. To change the agent name, you can set the AGENT_NAME environment variable. 2. To change the agent port, you can set the AGENT_PORT environment variable. @@ -54,7 +63,7 @@ Tips: server.StartAgentServer(t, opts) if socketproxy.ListenAddr != "" { - logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr) + log.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr) opts := httpServer.Options{ Name: "docker", HTTPAddr: socketproxy.ListenAddr, diff --git a/agent/go.mod b/agent/go.mod index 567f762..8b6e7d0 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy => .. replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy +replace github.com/yusing/go-proxy/internal/utils => ../internal/utils + replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30 replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327 @@ -15,6 +17,7 @@ require ( github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.10.0 github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000 + github.com/yusing/go-proxy/internal/utils v0.0.0 github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000 ) diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index 9840486..5887a66 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -15,8 +15,8 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/certs" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/pkg" ) @@ -115,7 +115,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) cfg.name = string(name) - cfg.l = logging.With().Str("agent", cfg.name).Logger() + cfg.l = log.With().Str("agent", cfg.name).Logger() // check agent version agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion) @@ -127,10 +127,10 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) agentVersion := pkg.ParseVersion(cfg.version) if serverVersion.IsNewerMajorThan(agentVersion) { - logging.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion) + log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion) } - logging.Info().Msgf("agent %q initialized", cfg.name) + log.Info().Msgf("agent %q initialized", cfg.name) return nil } diff --git a/agent/pkg/handler/check_health_test.go b/agent/pkg/handler/check_health_test.go index 4633ed6..2fc023f 100644 --- a/agent/pkg/handler/check_health_test.go +++ b/agent/pkg/handler/check_health_test.go @@ -172,9 +172,9 @@ func TestCheckHealthTCPUDP(t *testing.T) { { name: "InvalidHost", scheme: "tcp", - host: "invalid", + host: "", port: 8080, - expectedStatus: http.StatusOK, + expectedStatus: http.StatusBadRequest, expectedHealthy: false, }, { @@ -188,9 +188,17 @@ func TestCheckHealthTCPUDP(t *testing.T) { { name: "InvalidHost", scheme: "udp", - host: "invalid", + host: "", port: 8080, - expectedStatus: http.StatusOK, + expectedStatus: http.StatusBadRequest, + expectedHealthy: false, + }, + { + name: "Port in both host and port", + scheme: "tcp", + host: "localhost:1234", + port: 1234, + expectedStatus: http.StatusBadRequest, expectedHealthy: false, }, } @@ -208,9 +216,11 @@ func TestCheckHealthTCPUDP(t *testing.T) { require.Equal(t, recorder.Code, tt.expectedStatus) - var result health.HealthCheckResult - require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) - require.Equal(t, result.Healthy, tt.expectedHealthy) + if tt.expectedStatus == http.StatusOK { + var result health.HealthCheckResult + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result)) + require.Equal(t, result.Healthy, tt.expectedHealthy) + } }) } } diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go index 9d07556..1835616 100644 --- a/agent/pkg/handler/handler.go +++ b/agent/pkg/handler/handler.go @@ -1,17 +1,14 @@ package handler import ( - "context" "fmt" - "net" "net/http" - "net/http/httputil" - "time" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/internal/metrics/systeminfo" "github.com/yusing/go-proxy/pkg" + socketproxy "github.com/yusing/go-proxy/socketproxy/pkg" ) type ServeMux struct{ *http.ServeMux } @@ -24,26 +21,6 @@ func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) { mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler) } -var dialer = &net.Dialer{KeepAlive: 1 * time.Second} - -func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "unix", env.DockerSocket) -} - -func dockerSocketHandler() http.HandlerFunc { - rp := httputil.ReverseProxy{ - Director: func(r *http.Request) { - r.URL.Scheme = "http" - r.URL.Host = "api.moby.localhost" - r.RequestURI = r.URL.String() - }, - Transport: &http.Transport{ - DialContext: dialDockerSocket, - }, - } - return rp.ServeHTTP -} - func NewAgentHandler() http.Handler { mux := ServeMux{http.NewServeMux()} @@ -54,6 +31,6 @@ func NewAgentHandler() http.Handler { }) mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth) mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP) - mux.ServeMux.HandleFunc("/", dockerSocketHandler()) + mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler(env.DockerSocket)) return mux } diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index 9be4631..4b9e2b7 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -6,9 +6,9 @@ import ( "fmt" "net/http" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/handler" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/task" ) @@ -33,12 +33,11 @@ func StartAgentServer(parent task.Parent, opt Options) { tlsConfig.ClientAuth = tls.NoClientCert } - logger := logging.GetLogger() agentServer := &http.Server{ Addr: fmt.Sprintf(":%d", opt.Port), Handler: handler.NewAgentHandler(), TLSConfig: tlsConfig, } - server.Start(parent, agentServer, nil, logger) + server.Start(parent, agentServer, nil, &log.Logger) } diff --git a/cmd/main.go b/cmd/main.go index 125ec69..c2cb3c1 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,6 +4,7 @@ import ( "os" "sync" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/auth" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -35,8 +36,8 @@ func main() { initProfiling() logging.InitLogger(os.Stderr, memlogger.GetMemLogger()) - logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion()) - logging.Trace().Msg("trace enabled") + log.Info().Msgf("GoDoxy version %s", pkg.GetVersion()) + log.Trace().Msg("trace enabled") parallel( dnsproviders.InitProviders, homepage.InitIconListCache, @@ -45,7 +46,7 @@ func main() { ) if common.APIJWTSecret == nil { - logging.Warn().Msg("API_JWT_SECRET is not set, using random key") + log.Warn().Msg("API_JWT_SECRET is not set, using random key") common.APIJWTSecret = common.RandomJWTKey() } @@ -62,7 +63,7 @@ func main() { Proxy: true, }) if err := auth.Initialize(); err != nil { - logging.Fatal().Err(err).Msg("failed to initialize authentication") + log.Fatal().Err(err).Msg("failed to initialize authentication") } // API Handler needs to start after auth is initialized. cfg.StartServers(&config.StartServersOptions{ @@ -78,7 +79,7 @@ func main() { func prepareDirectory(dir string) { if _, err := os.Stat(dir); os.IsNotExist(err) { if err = os.MkdirAll(dir, 0o755); err != nil { - logging.Fatal().Msgf("failed to create directory %s: %v", dir, err) + log.Fatal().Msgf("failed to create directory %s: %v", dir, err) } } } diff --git a/go.mod b/go.mod index 38af281..df26fd2 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy/agent => ./agent replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders +replace github.com/yusing/go-proxy/internal/utils => ./internal/utils + replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22 replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30 @@ -45,7 +47,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000 github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000 - go.uber.org/atomic v1.11.0 + github.com/yusing/go-proxy/internal/utils v0.0.0 ) require ( @@ -219,6 +221,7 @@ require ( go.opentelemetry.io/otel v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect go.opentelemetry.io/otel/trace v1.36.0 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/mock v0.5.2 // indirect go.uber.org/multierr v1.11.0 // indirect @@ -226,7 +229,7 @@ require ( golang.org/x/mod v0.24.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.25.0 + golang.org/x/text v0.25.0 // indirect golang.org/x/tools v0.33.0 // indirect google.golang.org/api v0.234.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect diff --git a/internal/acl/config.go b/internal/acl/config.go index 2b7c5ef..8d2f306 100644 --- a/internal/acl/config.go +++ b/internal/acl/config.go @@ -5,9 +5,9 @@ import ( "time" "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/accesslog" "github.com/yusing/go-proxy/internal/maxmind" "github.com/yusing/go-proxy/internal/task" @@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool { return c.created.Add(cacheTTL).Before(utils.TimeNow()) } -//TODO: add stats +// TODO: add stats const ( ACLAllow = "allow" @@ -97,7 +97,7 @@ func (c *Config) Start(parent *task.Task) gperr.Error { if c.valErr != nil { return c.valErr } - logging.Info(). + log.Info(). Str("default", c.Default). Bool("allow_local", c.allowLocal). Int("allow_rules", len(c.Allow)). diff --git a/internal/acl/matcher_test.go b/internal/acl/matcher_test.go index 384b9b8..0d015b4 100644 --- a/internal/acl/matcher_test.go +++ b/internal/acl/matcher_test.go @@ -6,7 +6,7 @@ import ( "testing" maxmind "github.com/yusing/go-proxy/internal/maxmind/types" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) func TestMatchers(t *testing.T) { @@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) { } var mathers Matchers - err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false) + err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false) if err != nil { t.Fatal(err) } diff --git a/internal/acl/tcp_listener.go b/internal/acl/tcp_listener.go index b6b4a7a..59e40ee 100644 --- a/internal/acl/tcp_listener.go +++ b/internal/acl/tcp_listener.go @@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil } func (noConn) SetReadDeadline(t time.Time) error { return nil } func (noConn) SetWriteDeadline(t time.Time) error { return nil } -func (cfg *Config) WrapTCP(lis net.Listener) net.Listener { - if cfg == nil { +func (c *Config) WrapTCP(lis net.Listener) net.Listener { + if c == nil { return lis } return &TCPListener{ - acl: cfg, + acl: c, lis: lis, } } diff --git a/internal/api/v1/certapi/renew.go b/internal/api/v1/certapi/renew.go index b274ef4..084db82 100644 --- a/internal/api/v1/certapi/renew.go +++ b/internal/api/v1/certapi/renew.go @@ -3,9 +3,9 @@ package certapi import ( "net/http" + "github.com/rs/zerolog/log" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" ) @@ -36,7 +36,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) { gperr.LogError("failed to obtain cert", err) _ = gpwebsocket.WriteText(conn, err.Error()) } else { - logging.Info().Msg("cert obtained successfully") + log.Info().Msg("cert obtained successfully") } }() for { diff --git a/internal/api/v1/dockerapi/logs.go b/internal/api/v1/dockerapi/logs.go index 385b59e..c313fed 100644 --- a/internal/api/v1/dockerapi/logs.go +++ b/internal/api/v1/dockerapi/logs.go @@ -9,12 +9,13 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/pkg/stdcopy" "github.com/gorilla/websocket" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/task" ) +// FIXME: agent logs not updating. func Logs(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() server := r.PathValue("server") @@ -68,7 +69,7 @@ func Logs(w http.ResponseWriter, r *http.Request) { if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) { return } - logging.Err(err). + log.Err(err). Str("server", server). Str("container", containerID). Msg("failed to de-multiplex logs") diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index 5d3d022..0fc0335 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -11,9 +11,9 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/jsonstore" - "github.com/yusing/go-proxy/internal/logging" "golang.org/x/oauth2" ) @@ -108,11 +108,11 @@ func storeOAuthRefreshToken(sessionID sessionID, username, token string) { RefreshToken: token, Expiry: time.Now().Add(defaultRefreshTokenExpiry), }) - logging.Debug().Str("username", username).Msg("stored oauth refresh token") + log.Debug().Str("username", username).Msg("stored oauth refresh token") } func invalidateOAuthRefreshToken(sessionID sessionID) { - logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token") + log.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token") oauthRefreshTokens.Delete(string(sessionID)) } @@ -127,7 +127,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) signed, err := jwtToken.SignedString(common.APIJWTSecret) if err != nil { - logging.Err(err).Msg("failed to sign session token") + log.Err(err).Msg("failed to sign session token") return } SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL) @@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut return nil, refreshToken.err } - idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken) + idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken) if err != nil { refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) return nil, refreshToken.err @@ -205,7 +205,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut sessionID := newSessionID() - logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") + log.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken) refreshToken.result = &RefreshResult{ diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 4c8c0b2..7f9d4ed 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -12,9 +12,9 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/utils" "golang.org/x/oauth2" @@ -38,8 +38,8 @@ type ( const ( CookieOauthState = "godoxy_oidc_state" - CookieOauthToken = "godoxy_oauth_token" - CookieOauthSessionToken = "godoxy_session_token" + CookieOauthToken = "godoxy_oauth_token" //nolint:gosec + CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec ) const ( @@ -79,7 +79,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all endSessionURL, err := url.Parse(provider.EndSessionEndpoint()) if err != nil && provider.EndSessionEndpoint() != "" { // non critical, just warn - logging.Warn(). + log.Warn(). Str("issuer", issuerURL). Err(err). Msg("failed to parse end session URL") @@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption { return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath) } -func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) { +func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) { idTokenJWT, ok := oauthToken.Extra("id_token").(string) if !ok { return "", nil, errMissingIDToken @@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { return } // clear cookies then redirect to home - logging.Err(err).Msg("failed to refresh token") + log.Err(err).Msg("failed to refresh token") auth.clearCookie(w, r) http.Redirect(w, r, "/", http.StatusFound) return @@ -201,11 +201,12 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) { func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool { userAllowed := slices.Contains(auth.allowedUsers, user) - if !userAllowed { - return false + if userAllowed { + return true } if len(auth.allowedGroups) == 0 { - return true + // user is not allowed, but no groups are allowed + return false } return len(utils.Intersect(groups, auth.allowedGroups)) > 0 } @@ -257,7 +258,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http return } - idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token) + idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token) if err != nil { gphttp.ServerError(w, r, err) return diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 0934d96..a71d429 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -5,13 +5,16 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/x509" + "net/http" "os" "regexp" "github.com/go-acme/lego/v4/certcrypto" + "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" + "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils" ) @@ -22,13 +25,19 @@ type Config struct { KeyPath string `json:"key_path,omitempty"` ACMEKeyPath string `json:"acme_key_path,omitempty"` Provider string `json:"provider,omitempty"` + CADirURL string `json:"ca_dir_url,omitempty"` Options map[string]any `json:"options,omitempty"` + + HTTPClient *http.Client `json:"-"` // for tests only + + challengeProvider challenge.Provider } var ( ErrMissingDomain = gperr.New("missing field 'domains'") ErrMissingEmail = gperr.New("missing field 'email'") ErrMissingProvider = gperr.New("missing field 'provider'") + ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'") ErrInvalidDomain = gperr.New("invalid domain") ErrUnknownProvider = gperr.New("unknown provider") ) @@ -36,6 +45,7 @@ var ( const ( ProviderLocal = "local" ProviderPseudo = "pseudo" + ProviderCustom = "custom" ) var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) @@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error { } b := gperr.NewBuilder("autocert errors") + if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { + b.Add(ErrMissingCADirURL) + } + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if len(cfg.Domains) == 0 { b.Add(ErrMissingDomain) @@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error { if cfg.Email == "" { b.Add(ErrMissingEmail) } - for i, d := range cfg.Domains { - if !domainOrWildcardRE.MatchString(d) { - b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + if cfg.Provider != ProviderCustom { + for i, d := range cfg.Domains { + if !domainOrWildcardRE.MatchString(d) { + b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + } } } // check if provider is implemented providerConstructor, ok := Providers[cfg.Provider] if !ok { - b.Add(ErrUnknownProvider. - Subject(cfg.Provider). - With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers)))) + if cfg.Provider != ProviderCustom { + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers)))) + } } else { - _, err := providerConstructor(cfg.Options) + provider, err := providerConstructor(cfg.Options) if err != nil { b.Add(err) + } else { + cfg.challengeProvider = provider } } } + + if cfg.challengeProvider == nil { + cfg.challengeProvider, _ = Providers[ProviderLocal](nil) + } return b.Error() } @@ -100,8 +124,7 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if privKey, err = cfg.LoadACMEKey(); err != nil { - logging.Info().Err(err).Msg("load ACME private key failed") - logging.Info().Msg("generate new ACME private key") + log.Info().Err(err).Msg("failed to load ACME private key, generating a now one") privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, nil, gperr.New("generate ACME private key").With(err) @@ -118,12 +141,23 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { } legoCfg := lego.NewConfig(user) - legoCfg.Certificate.KeyType = certcrypto.RSA2048 + legoCfg.Certificate.KeyType = certcrypto.EC256 + + if cfg.HTTPClient != nil { + legoCfg.HTTPClient = cfg.HTTPClient + } + + if cfg.CADirURL != "" { + legoCfg.CADirURL = cfg.CADirURL + } return user, legoCfg, nil } func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) { + if common.IsTest { + return nil, os.ErrNotExist + } data, err := os.ReadFile(cfg.ACMEKeyPath) if err != nil { return nil, err @@ -132,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) { } func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error { + if common.IsTest { + return nil + } data, err := x509.MarshalECPrivateKey(key) if err != nil { return err diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 8c2585b..12c0125 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -5,18 +5,20 @@ import ( "crypto/x509" "errors" "fmt" + "maps" "os" "path" - "reflect" - "sort" + "slices" + "strings" "time" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -76,13 +78,11 @@ func (p *Provider) ObtainCert() error { } if p.cfg.Provider == ProviderPseudo { - t := time.NewTicker(1000 * time.Millisecond) - defer t.Stop() - logging.Info().Msg("init client for pseudo provider") - <-t.C - logging.Info().Msg("registering acme for pseudo provider") - <-t.C - logging.Info().Msg("obtained cert for pseudo provider") + log.Info().Msg("init client for pseudo provider") + <-time.After(time.Second) + log.Info().Msg("registering acme for pseudo provider") + <-time.After(time.Second) + log.Info().Msg("obtained cert for pseudo provider") return nil } @@ -107,7 +107,7 @@ func (p *Provider) ObtainCert() error { }) if err != nil { p.legoCert = nil - logging.Err(err).Msg("cert renew failed, fallback to obtain") + log.Err(err).Msg("cert renew failed, fallback to obtain") } else { p.legoCert = cert } @@ -154,7 +154,7 @@ func (p *Provider) LoadCert() error { p.tlsCert = &cert p.certExpiries = expiries - logging.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn()))) + log.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn()))) return p.renewIfNeeded() } @@ -219,13 +219,7 @@ func (p *Provider) initClient() error { return err } - generator := Providers[p.cfg.Provider] - legoProvider, pErr := generator(p.cfg.Options) - if pErr != nil { - return pErr - } - - err = legoClient.Challenge.SetDNS01Provider(legoProvider) + err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider) if err != nil { return err } @@ -240,7 +234,7 @@ func (p *Provider) registerACME() error { } if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil { p.user.Registration = reg - logging.Info().Msg("reused acme registration from private key") + log.Info().Msg("reused acme registration from private key") return nil } @@ -249,11 +243,14 @@ func (p *Provider) registerACME() error { return err } p.user.Registration = reg - logging.Info().Interface("reg", reg).Msg("acme registered") + log.Info().Interface("reg", reg).Msg("acme registered") return nil } func (p *Provider) saveCert(cert *certificate.Resource) error { + if common.IsTest { + return nil + } /* This should have been done in setup but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) @@ -283,22 +280,19 @@ func (p *Provider) certState() CertState { return CertStateExpired } - certDomains := make([]string, len(p.certExpiries)) - wantedDomains := make([]string, len(p.cfg.Domains)) - i := 0 - for domain := range p.certExpiries { - certDomains[i] = domain - i++ - } - copy(wantedDomains, p.cfg.Domains) - sort.Strings(wantedDomains) - sort.Strings(certDomains) - - if !reflect.DeepEqual(certDomains, wantedDomains) { - logging.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) + if len(p.certExpiries) != len(p.cfg.Domains) { return CertStateMismatch } + for i := range len(p.cfg.Domains) { + if _, ok := p.certExpiries[p.cfg.Domains[i]]; !ok { + log.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s", + strings.Join(slices.Collect(maps.Keys(p.certExpiries)), ", "), + strings.Join(p.cfg.Domains, ", ")) + return CertStateMismatch + } + } + return CertStateValid } @@ -309,9 +303,9 @@ func (p *Provider) renewIfNeeded() error { switch p.certState() { case CertStateExpired: - logging.Info().Msg("certs expired, renewing") + log.Info().Msg("certs expired, renewing") case CertStateMismatch: - logging.Info().Msg("cert domains mismatch with config, renewing") + log.Info().Msg("cert domains mismatch with config, renewing") default: return nil } diff --git a/internal/autocert/provider_test/custom_test.go b/internal/autocert/provider_test/custom_test.go new file mode 100644 index 0000000..6490761 --- /dev/null +++ b/internal/autocert/provider_test/custom_test.go @@ -0,0 +1,453 @@ +//nolint:errchkjson,errcheck +package provider_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/yusing/go-proxy/internal/autocert" + "github.com/yusing/go-proxy/internal/dnsproviders" +) + +func TestMain(m *testing.M) { + dnsproviders.InitProviders() + m.Run() +} + +func TestCustomProvider(t *testing.T) { + t.Run("valid custom provider with step-ca", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"example.com", "*.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: "https://ca.example.com:9000/acme/acme/directory", + CertPath: "certs/custom.crt", + KeyPath: "certs/custom.key", + ACMEKeyPath: "certs/custom-acme.key", + } + + err := cfg.Validate() + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL) + require.Equal(t, "test@example.com", user.Email) + }) + + t.Run("custom provider missing CADirURL", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"example.com"}, + Provider: autocert.ProviderCustom, + // CADirURL is missing + } + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "missing field 'ca_dir_url'") + }) + + t.Run("custom provider with step-ca internal CA", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "admin@internal.com", + Domains: []string{"internal.example.com", "api.internal.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: "https://step-ca.internal:443/acme/acme/directory", + CertPath: "certs/internal.crt", + KeyPath: "certs/internal.key", + ACMEKeyPath: "certs/internal-acme.key", + } + + err := cfg.Validate() + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL) + require.Equal(t, "admin@internal.com", user.Email) + + provider := autocert.NewProvider(cfg, user, legoCfg) + require.NotNil(t, provider) + require.Equal(t, autocert.ProviderCustom, provider.GetName()) + require.Equal(t, "certs/internal.crt", provider.GetCertPath()) + require.Equal(t, "certs/internal.key", provider.GetKeyPath()) + }) +} + +func TestObtainCertFromCustomProvider(t *testing.T) { + // Create a test ACME server + acmeServer := newTestACMEServer(t) + defer acmeServer.Close() + + t.Run("obtain cert from custom step-ca server", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"test.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: acmeServer.URL() + "/acme/acme/directory", + CertPath: "certs/stepca-test.crt", + KeyPath: "certs/stepca-test.key", + ACMEKeyPath: "certs/stepca-test-acme.key", + HTTPClient: acmeServer.httpClient(), + } + + err := error(cfg.Validate()) + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + + provider := autocert.NewProvider(cfg, user, legoCfg) + require.NotNil(t, provider) + + // Test obtaining certificate + err = provider.ObtainCert() + require.NoError(t, err) + + // Verify certificate was obtained + cert, err := provider.GetCert(nil) + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify certificate properties + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, x509Cert.DNSNames, "test.example.com") + require.True(t, time.Now().Before(x509Cert.NotAfter)) + require.True(t, time.Now().After(x509Cert.NotBefore)) + }) +} + +// testACMEServer implements a minimal ACME server for testing. +type testACMEServer struct { + server *httptest.Server + caCert *x509.Certificate + caKey *rsa.PrivateKey + clientCSRs map[string]*x509.CertificateRequest + orderID string +} + +func newTestACMEServer(t *testing.T) *testACMEServer { + t.Helper() + + // Generate CA certificate and key + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test CA"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"Test"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + caCert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + acme := &testACMEServer{ + caCert: caCert, + caKey: caKey, + clientCSRs: make(map[string]*x509.CertificateRequest), + orderID: "test-order-123", + } + + mux := http.NewServeMux() + acme.setupRoutes(mux) + + acme.server = httptest.NewTLSServer(mux) + return acme +} + +func (s *testACMEServer) Close() { + s.server.Close() +} + +func (s *testACMEServer) URL() string { + return s.server.URL +} + +func (s *testACMEServer) httpClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 30 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + }, + }, + } +} + +func (s *testACMEServer) setupRoutes(mux *http.ServeMux) { + // ACME directory endpoint + mux.HandleFunc("/acme/acme/directory", s.handleDirectory) + + // ACME endpoints + mux.HandleFunc("/acme/new-nonce", s.handleNewNonce) + mux.HandleFunc("/acme/new-account", s.handleNewAccount) + mux.HandleFunc("/acme/new-order", s.handleNewOrder) + mux.HandleFunc("/acme/authz/", s.handleAuthorization) + mux.HandleFunc("/acme/chall/", s.handleChallenge) + mux.HandleFunc("/acme/order/", s.handleOrder) + mux.HandleFunc("/acme/cert/", s.handleCertificate) +} + +func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) { + directory := map[string]interface{}{ + "newNonce": s.server.URL + "/acme/new-nonce", + "newAccount": s.server.URL + "/acme/new-account", + "newOrder": s.server.URL + "/acme/new-order", + "keyChange": s.server.URL + "/acme/key-change", + "meta": map[string]interface{}{ + "termsOfService": s.server.URL + "/terms", + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(directory) +} + +func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "test-nonce-12345") + w.WriteHeader(http.StatusOK) +} + +func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) { + account := map[string]interface{}{ + "status": "valid", + "contact": []string{"mailto:test@example.com"}, + "orders": s.server.URL + "/acme/orders", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/account/1") + w.Header().Set("Replay-Nonce", "test-nonce-67890") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(account) +} + +func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) { + authzID := "test-authz-456" + + order := map[string]interface{}{ + "status": "ready", // Skip pending state for simplicity + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "authorizations": []string{s.server.URL + "/acme/authz/" + authzID}, + "finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID) + w.Header().Set("Replay-Nonce", "test-nonce-order") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) { + authz := map[string]interface{}{ + "status": "valid", // Skip challenge validation for simplicity + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifier": map[string]string{"type": "dns", "value": "test.example.com"}, + "challenges": []map[string]interface{}{ + { + "type": "dns-01", + "status": "valid", + "url": s.server.URL + "/acme/chall/test-chall-789", + "token": "test-token-abc123", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-authz") + json.NewEncoder(w).Encode(authz) +} + +func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) { + challenge := map[string]interface{}{ + "type": "dns-01", + "status": "valid", + "url": r.URL.String(), + "token": "test-token-abc123", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-chall") + json.NewEncoder(w).Encode(challenge) +} + +func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/finalize") { + s.handleFinalize(w, r) + return + } + + certURL := s.server.URL + "/acme/cert/" + s.orderID + order := map[string]interface{}{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "certificate": certURL, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-order-get") + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) { + // Read the JWS payload + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request", http.StatusBadRequest) + return + } + + // Extract CSR from JWS payload + csr, err := s.extractCSRFromJWS(body) + if err != nil { + http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest) + return + } + + // Store the CSR for certificate generation + s.clientCSRs[s.orderID] = csr + + certURL := s.server.URL + "/acme/cert/" + s.orderID + order := map[string]interface{}{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "certificate": certURL, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize")) + w.Header().Set("Replay-Nonce", "test-nonce-finalize") + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) { + // Parse the JWS structure + var jws struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Signature string `json:"signature"` + } + + if err := json.Unmarshal(jwsData, &jws); err != nil { + return nil, err + } + + // Decode the payload + payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return nil, err + } + + // Parse the finalize request + var finalizeReq struct { + CSR string `json:"csr"` + } + + if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil { + return nil, err + } + + // Decode the CSR + csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR) + if err != nil { + return nil, err + } + + // Parse the CSR + csr, err := x509.ParseCertificateRequest(csrBytes) + if err != nil { + return nil, err + } + + return csr, nil +} + +func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) { + // Extract order ID from URL + orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/") + + // Get the CSR for this order + csr, exists := s.clientCSRs[orderID] + if !exists { + http.Error(w, "No CSR found for order", http.StatusBadRequest) + return + } + + // Create certificate using the public key from the client's CSR + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Cert"}, + Country: []string{"US"}, + }, + DNSNames: csr.DNSNames, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Use the public key from the CSR and sign with CA key + certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Return certificate chain + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw}) + + w.Header().Set("Content-Type", "application/pem-certificate-chain") + w.Header().Set("Replay-Nonce", "test-nonce-cert") + w.Write(append(certPEM, caPEM...)) +} diff --git a/internal/autocert/provider_test/ovh_test.go b/internal/autocert/provider_test/ovh_test.go index 203e9ef..e268af4 100644 --- a/internal/autocert/provider_test/ovh_test.go +++ b/internal/autocert/provider_test/ovh_test.go @@ -6,7 +6,7 @@ import ( "github.com/go-acme/lego/v4/providers/dns/ovh" "github.com/goccy/go-yaml" "github.com/stretchr/testify/require" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) // type Config struct { @@ -45,6 +45,6 @@ oauth2_config: testYaml = testYaml[1:] // remove first \n opt := make(map[string]any) require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt)) - require.NoError(t, utils.MapUnmarshalValidate(opt, cfg)) + require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg)) require.Equal(t, cfgExpected, cfg) } diff --git a/internal/autocert/providers.go b/internal/autocert/providers.go index dbbcb6d..5b73e39 100644 --- a/internal/autocert/providers.go +++ b/internal/autocert/providers.go @@ -3,7 +3,7 @@ package autocert import ( "github.com/go-acme/lego/v4/challenge" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type Generator func(map[string]any) (challenge.Provider, gperr.Error) @@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider]( ) Generator { return func(opt map[string]any) (challenge.Provider, gperr.Error) { cfg := defaultCfg() - err := utils.MapUnmarshalValidate(opt, &cfg) - if err != nil { - return nil, err + if len(opt) > 0 { + err := serialization.MapUnmarshalValidate(opt, &cfg) + if err != nil { + return nil, err + } } p, pErr := newProvider(cfg) return p, gperr.Wrap(pErr) diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index b436bed..62c589e 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -4,7 +4,7 @@ import ( "errors" "os" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -13,14 +13,14 @@ func (p *Provider) Setup() (err error) { if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist return err } - logging.Debug().Msg("obtaining cert due to error loading cert") + log.Debug().Msg("obtaining cert due to error loading cert") if err = p.ObtainCert(); err != nil { return err } } for _, expiry := range p.GetExpiries() { - logging.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) + log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) break } diff --git a/internal/config/config.go b/internal/config/config.go index 621cfa0..02cac28 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,20 +10,20 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/api" autocert "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/entrypoint" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/maxmind" "github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/proxmox" proxy "github.com/yusing/go-proxy/internal/route/provider" + "github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/strutils/ansi" "github.com/yusing/go-proxy/internal/watcher" @@ -96,10 +96,10 @@ func OnConfigChange(ev []events.Event) { // just reload once and check the last event switch ev[len(ev)-1].Action { case events.ActionFileRenamed: - logging.Warn().Msg(cfgRenameWarn) + log.Warn().Msg(cfgRenameWarn) return case events.ActionFileDeleted: - logging.Warn().Msg(cfgDeleteWarn) + log.Warn().Msg(cfgDeleteWarn) return } @@ -161,7 +161,7 @@ func (cfg *Config) Start(opts ...*StartServersOptions) { func (cfg *Config) StartAutoCert() { autocert := cfg.autocertProvider if autocert == nil { - logging.Info().Msg("autocert not configured") + log.Info().Msg("autocert not configured") return } @@ -223,7 +223,7 @@ func (cfg *Config) load() gperr.Error { } model := config.DefaultConfig() - if err := utils.UnmarshalValidateYAML(data, model); err != nil { + if err := serialization.UnmarshalValidateYAML(data, model); err != nil { gperr.LogFatal(errMsg, err) } @@ -374,6 +374,6 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error { } results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes()) }) - logging.Info().Msg(results.String()) + log.Info().Msg(results.String()) return errs.Error() } diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 4460025..80fa48e 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -14,7 +14,7 @@ import ( maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/proxmox" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( @@ -93,14 +93,14 @@ func HasInstance() bool { func Validate(data []byte) gperr.Error { var model Config - return utils.UnmarshalValidateYAML(data, &model) + return serialization.UnmarshalValidateYAML(data, &model) } var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`) func init() { - utils.RegisterDefaultValueFactory(DefaultConfig) - utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool { + serialization.RegisterDefaultValueFactory(DefaultConfig) + serialization.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool { domains := fl.Field().Interface().([]string) for _, domain := range domains { if !matchDomainsRegex.MatchString(domain) { @@ -109,7 +109,7 @@ func init() { } return true }) - utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool { + serialization.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool { m := fl.Field().Interface().(map[string]string) for k := range m { if k == "" { diff --git a/internal/dnsproviders/go.mod b/internal/dnsproviders/go.mod index 5506b95..65efe61 100644 --- a/internal/dnsproviders/go.mod +++ b/internal/dnsproviders/go.mod @@ -4,6 +4,8 @@ go 1.24.3 replace github.com/yusing/go-proxy => ../.. +replace github.com/yusing/go-proxy/internal/utils => ../utils + require ( github.com/go-acme/lego/v4 v4.23.1 github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000 @@ -156,6 +158,7 @@ require ( github.com/vultr/govultr/v3 v3.20.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + github.com/yusing/go-proxy/internal/utils v0.0.0 // indirect go.mongodb.org/mongo-driver v1.17.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect diff --git a/internal/docker/client.go b/internal/docker/client.go index afadf20..6d15d88 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -13,13 +13,14 @@ import ( "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) +// TODO: implement reconnect here. type ( SharedClient struct { *client.Client @@ -83,7 +84,7 @@ func closeTimedOutClients() { if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs { delete(clientMap, c.Key()) c.Client.Close() - logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed") + log.Debug().Str("host", c.DaemonHost()).Msg("docker client closed") } } } @@ -148,7 +149,7 @@ func NewClient(host string) (*SharedClient, error) { default: helper, err := connhelper.GetConnectionHelper(host) if err != nil { - logging.Panic().Err(err).Msg("failed to get connection helper") + log.Panic().Err(err).Msg("failed to get connection helper") } if helper != nil { httpClient := &http.Client{ @@ -189,10 +190,10 @@ func NewClient(host string) (*SharedClient, error) { c.dial = client.Dialer() } if c.addr == "" { - c.addr = c.Client.DaemonHost() + c.addr = c.DaemonHost() } - defer logging.Debug().Str("host", host).Msg("docker client initialized") + defer log.Debug().Str("host", host).Msg("docker client initialized") clientMap[c.Key()] = c return c, nil diff --git a/internal/docker/container.go b/internal/docker/container.go index 6044b30..ddc4ac6 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -9,11 +9,12 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/agent" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/gperr" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" - "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/utils" ) @@ -90,7 +91,7 @@ func FromDocker(c *container.SummaryTrimmed, dockerHost string) (res *Container) var ok bool res.Agent, ok = config.GetInstance().GetAgent(dockerHost) if !ok { - logging.Error().Msgf("agent %q not found", dockerHost) + log.Error().Msgf("agent %q not found", dockerHost) } } @@ -183,7 +184,7 @@ func (c *Container) setPublicHostname() { } url, err := url.Parse(c.DockerHost) if err != nil { - logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost) + log.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost) c.PublicHostname = "127.0.0.1" return } @@ -224,7 +225,7 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) { ContainerName: c.ContainerName, }, } - err := utils.MapUnmarshalValidate(cfg, idwCfg) + err := serialization.MapUnmarshalValidate(cfg, idwCfg) if err != nil { gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err)) } else { diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index e3683db..aa35b9c 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -6,7 +6,7 @@ import ( "net/http" "strings" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/logging/accesslog" gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/middleware" @@ -50,7 +50,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { } ep.middleware = mid - logging.Debug().Msg("entrypoint middleware loaded") + log.Debug().Msg("entrypoint middleware loaded") return nil } @@ -64,7 +64,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request if err != nil { return } - logging.Debug().Msg("entrypoint access logger created") + log.Debug().Msg("entrypoint access logger created") return } @@ -89,7 +89,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Then scraper / scanners will know the subdomain is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if served := middleware.ServeStaticErrorPageFile(w, r); !served { - logging.Err(err). + log.Err(err). Str("method", r.Method). Str("url", r.URL.String()). Str("remote", r.RemoteAddr). @@ -99,7 +99,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) w.Header().Set("Content-Type", "text/html; charset=utf-8") if _, err := w.Write(errorPage); err != nil { - logging.Err(err).Msg("failed to write error page") + log.Err(err).Msg("failed to write error page") } } else { http.Error(w, err.Error(), http.StatusNotFound) diff --git a/internal/gperr/log.go b/internal/gperr/log.go index 5be1bce..a13336b 100644 --- a/internal/gperr/log.go +++ b/internal/gperr/log.go @@ -4,8 +4,8 @@ import ( "os" "github.com/rs/zerolog" + zerologlog "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" ) func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) { @@ -13,7 +13,7 @@ func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) if len(logger) > 0 { l = logger[0] } else { - l = logging.GetLogger() + l = &zerologlog.Logger } l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error()) switch level { diff --git a/internal/homepage/homepage.go b/internal/homepage/homepage.go index 2d9bcc3..13d78bb 100644 --- a/internal/homepage/homepage.go +++ b/internal/homepage/homepage.go @@ -5,7 +5,7 @@ import ( "slices" "github.com/yusing/go-proxy/internal/homepage/widgets" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( @@ -32,7 +32,7 @@ type ( ) func init() { - utils.RegisterDefaultValueFactory(func() *ItemConfig { + serialization.RegisterDefaultValueFactory(func() *ItemConfig { return &ItemConfig{ Show: true, } diff --git a/internal/homepage/icon_cache.go b/internal/homepage/icon_cache.go index 4f0f658..2a4f5c2 100644 --- a/internal/homepage/icon_cache.go +++ b/internal/homepage/icon_cache.go @@ -6,9 +6,9 @@ import ( "sync" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/jsonstore" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/atomic" @@ -74,7 +74,7 @@ func pruneExpiredIconCache() { } } if nPruned > 0 { - logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache") + log.Info().Int("pruned", nPruned).Msg("pruned expired icon cache") } } @@ -87,7 +87,7 @@ func loadIconCache(key string) *FetchResult { defer iconMu.RUnlock() icon, ok := iconCache.Load(key) if ok && len(icon.Icon) > 0 { - logging.Debug(). + log.Debug(). Str("key", key). Msg("icon found in cache") icon.LastAccess.Store(utils.TimeNow()) @@ -99,7 +99,7 @@ func loadIconCache(key string) *FetchResult { func storeIconCache(key string, result *FetchResult) { icon := result.Icon if len(icon) > maxIconSize { - logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size") + log.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size") return } @@ -109,7 +109,7 @@ func storeIconCache(key string, result *FetchResult) { entry := &cacheEntry{Icon: icon, ContentType: result.contentType} entry.LastAccess.Store(time.Now()) iconCache.Store(key, entry) - logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache") + log.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache") } func (e *cacheEntry) IsExpired() bool { diff --git a/internal/homepage/list_icons.go b/internal/homepage/list_icons.go index 734282d..72497a8 100644 --- a/internal/homepage/list_icons.go +++ b/internal/homepage/list_icons.go @@ -1,6 +1,7 @@ package homepage import ( + "context" "encoding/json" "fmt" "io" @@ -10,10 +11,10 @@ import ( "time" "github.com/lithammer/fuzzysearch/fuzzy" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -46,30 +47,30 @@ type ( func (icon *IconMeta) Filenames(ref string) []string { filenames := make([]string, 0) if icon.SVG { - filenames = append(filenames, fmt.Sprintf("%s.svg", ref)) + filenames = append(filenames, ref+".svg") if icon.Light { - filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref)) + filenames = append(filenames, ref+"-light.svg") } if icon.Dark { - filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref)) + filenames = append(filenames, ref+"-dark.svg") } } if icon.PNG { - filenames = append(filenames, fmt.Sprintf("%s.png", ref)) + filenames = append(filenames, ref+".png") if icon.Light { - filenames = append(filenames, fmt.Sprintf("%s-light.png", ref)) + filenames = append(filenames, ref+"-light.png") } if icon.Dark { - filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref)) + filenames = append(filenames, ref+"-dark.png") } } if icon.WebP { - filenames = append(filenames, fmt.Sprintf("%s.webp", ref)) + filenames = append(filenames, ref+".webp") if icon.Light { - filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref)) + filenames = append(filenames, ref+"-light.webp") } if icon.Dark { - filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref)) + filenames = append(filenames, ref+"-dark.webp") } } return filenames @@ -99,21 +100,21 @@ func InitIconListCache() { iconsCache.Lock() defer iconsCache.Unlock() - err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache) + err := serialization.LoadJSONIfExist(common.IconListCachePath, iconsCache) if err != nil { - logging.Error().Err(err).Msg("failed to load icons") + log.Error().Err(err).Msg("failed to load icons") } else if len(iconsCache.Icons) > 0 { - logging.Info(). + log.Info(). Int("icons", len(iconsCache.Icons)). Msg("icons loaded") } if err = updateIcons(); err != nil { - logging.Error().Err(err).Msg("failed to update icons") + log.Error().Err(err).Msg("failed to update icons") } task.OnProgramExit("save_icons_cache", func() { - utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644) + _ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644) }) } @@ -134,17 +135,17 @@ func ListAvailableIcons() (*Cache, error) { iconsCache.Lock() defer iconsCache.Unlock() - logging.Info().Msg("updating icon data") + log.Info().Msg("updating icon data") if err := updateIcons(); err != nil { return nil, err } - logging.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated") + log.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated") iconsCache.LastUpdate = time.Now() - err := utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644) + err := serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644) if err != nil { - logging.Warn().Err(err).Msg("failed to save icons") + log.Warn().Err(err).Msg("failed to save icons") } return iconsCache, nil } @@ -230,14 +231,17 @@ func updateIcons() error { var httpGet = httpGetImpl -func MockHttpGet(body []byte) { +func MockHTTPGet(body []byte) { httpGet = func(_ string) ([]byte, error) { return body, nil } } func httpGetImpl(url string) ([]byte, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } @@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error { } data := make([]SelfhStIcon, 0) - err = json.Unmarshal(body, &data) + err = json.Unmarshal(body, &data) //nolint:musttag if err != nil { return err } diff --git a/internal/homepage/list_icons_test.go b/internal/homepage/list_icons_test.go index 296c635..9569233 100644 --- a/internal/homepage/list_icons_test.go +++ b/internal/homepage/list_icons_test.go @@ -68,6 +68,8 @@ type testCases struct { } func runTests(t *testing.T, iconsCache *Cache, test []testCases) { + t.Helper() + for _, item := range test { icon, ok := iconsCache.Icons[item.Key] if !ok { @@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) { } func TestListWalkxCodeIcons(t *testing.T) { - MockHttpGet([]byte(walkxcodeIcons)) + MockHTTPGet([]byte(walkxcodeIcons)) if err := UpdateWalkxCodeIcons(); err != nil { t.Fatal(err) } @@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) { } func TestListSelfhstIcons(t *testing.T) { - MockHttpGet([]byte(selfhstIcons)) + MockHTTPGet([]byte(selfhstIcons)) if err := UpdateSelfhstIcons(); err != nil { t.Fatal(err) } diff --git a/internal/homepage/widgets/widgets.go b/internal/homepage/widgets/widgets.go index e653b69..6d62eb4 100644 --- a/internal/homepage/widgets/widgets.go +++ b/internal/homepage/widgets/widgets.go @@ -4,7 +4,7 @@ import ( "context" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( @@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{ var ErrInvalidProvider = gperr.New("invalid provider") func (cfg *Config) UnmarshalMap(m map[string]any) error { - cfg.Provider = m["provider"].(string) + var ok bool + cfg.Provider, ok = m["provider"].(string) + if !ok { + return ErrInvalidProvider.Withf("non string") + } if _, ok := widgetProviders[cfg.Provider]; !ok { return ErrInvalidProvider.Subject(cfg.Provider) } delete(m, "provider") - m, ok := m["config"].(map[string]any) + m, ok = m["config"].(map[string]any) if !ok { return gperr.New("invalid config") } - if err := utils.MapUnmarshalValidate(m, &cfg.Config); err != nil { - return err - } - return nil + return serialization.MapUnmarshalValidate(m, &cfg.Config) } diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go index 02f84e1..4cf1139 100644 --- a/internal/idlewatcher/watcher.go +++ b/internal/idlewatcher/watcher.go @@ -7,10 +7,10 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/idlewatcher/provider" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" net "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/route/routes" @@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{ } var ( - causeReload = gperr.New("reloaded") - causeContainerDestroy = gperr.New("container destroyed") + causeReload = gperr.New("reloaded") //nolint:errname + causeContainerDestroy = gperr.New("container destroyed") //nolint:errname ) const reqTimeout = 3 * time.Second -// TODO: fix stream type +// TODO: fix stream type. func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { cfg := r.IdlewatcherConfig() key := cfg.Key() @@ -120,7 +120,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) { return nil, err } w.provider = p - w.l = logging.With(). + w.l = log.With(). Str("provider", providerType). Str("container", cfg.ContainerName()). Logger() diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go index 54ff2ed..8fb0df2 100644 --- a/internal/jsonstore/jsonstore.go +++ b/internal/jsonstore/jsonstore.go @@ -2,18 +2,17 @@ package jsonstore import ( "encoding/json" + "maps" "os" "path/filepath" "reflect" - "maps" - "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils" ) type namespace string @@ -36,13 +35,15 @@ type store interface { json.Unmarshaler } -var stores = make(map[namespace]store) -var storesPath = common.DataDir +var ( + stores = make(map[namespace]store) + storesPath = common.DataDir +) func init() { task.OnProgramExit("save_stores", func() { if err := save(); err != nil { - logging.Error().Err(err).Msg("failed to save stores") + log.Error().Err(err).Msg("failed to save stores") } }) } @@ -54,20 +55,20 @@ func loadNS[T store](ns namespace) T { file, err := os.Open(path) if err != nil { if !os.IsNotExist(err) { - logging.Err(err). + log.Err(err). Str("path", path). Msg("failed to load store") } } else { defer file.Close() if err := json.NewDecoder(file).Decode(&store); err != nil { - logging.Err(err). + log.Err(err). Str("path", path). Msg("failed to load store") } } stores[ns] = store - logging.Debug(). + log.Debug(). Str("namespace", string(ns)). Str("path", path). Msg("loaded store") @@ -77,7 +78,7 @@ func loadNS[T store](ns namespace) T { func save() error { errs := gperr.NewBuilder("failed to save data stores") for ns, store := range stores { - if err := utils.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil { + if err := serialization.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil { errs.Add(err) } } @@ -86,7 +87,7 @@ func save() error { func Store[VT any](namespace namespace) MapStore[VT] { if _, ok := stores[namespace]; ok { - logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists") + log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists") } store := loadNS[*MapStore[VT]](namespace) stores[namespace] = store @@ -95,7 +96,7 @@ func Store[VT any](namespace namespace) MapStore[VT] { func Object[Ptr Initializer](namespace namespace) Ptr { if _, ok := stores[namespace]; ok { - logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists") + log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists") } obj := loadNS[*ObjectStore[Ptr]](namespace) stores[namespace] = obj @@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error { } s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp))) for k, v := range tmp { - s.Map.Store(k, v) + s.Store(k, v) } return nil } diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go index 8d6a83e..647181f 100644 --- a/internal/logging/accesslog/access_logger.go +++ b/internal/logging/accesslog/access_logger.go @@ -8,8 +8,8 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" maxmind "github.com/yusing/go-proxy/internal/maxmind/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils" @@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) { if err != nil { return nil, err } + if io == nil { + return nil, nil //nolint:nilnil + } return NewAccessLoggerWithIO(parent, io, cfg), nil } @@ -120,7 +123,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any bufSize: MinBufferSize, lineBufPool: synk.NewBytesPool(), errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst), - logger: logging.With().Str("file", writer.Name()).Logger(), + logger: log.With().Str("file", writer.Name()).Logger(), } l.supportRotate = unwrap[supportRotate](writer) @@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) { func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) { line := l.lineBufPool.Get() defer l.lineBufPool.Put(line) - line = l.ACLFormatter.AppendACLLog(line, info, blocked) + line = l.AppendACLLog(line, info, blocked) if line[len(line)-1] != '\n' { line = append(line, '\n') } @@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool { func (l *AccessLogger) Rotate() (result *RotateResult, err error) { if !l.ShouldRotate() { - return nil, nil + return nil, nil //nolint:nilnil } l.writer.Flush() diff --git a/internal/logging/accesslog/config.go b/internal/logging/accesslog/config.go index e9b90d5..9ac0c5d 100644 --- a/internal/logging/accesslog/config.go +++ b/internal/logging/accesslog/config.go @@ -4,7 +4,7 @@ import ( "time" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( @@ -126,6 +126,6 @@ func DefaultACLLoggerConfig() *ACLLoggerConfig { } func init() { - utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig) - utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig) + serialization.RegisterDefaultValueFactory(DefaultRequestLoggerConfig) + serialization.RegisterDefaultValueFactory(DefaultACLLoggerConfig) } diff --git a/internal/logging/accesslog/config_test.go b/internal/logging/accesslog/config_test.go index a44199d..37ccc54 100644 --- a/internal/logging/accesslog/config_test.go +++ b/internal/logging/accesslog/config_test.go @@ -5,7 +5,7 @@ import ( "github.com/yusing/go-proxy/internal/docker" . "github.com/yusing/go-proxy/internal/logging/accesslog" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" expect "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) { expect.NoError(t, err) var config RequestLoggerConfig - err = utils.MapUnmarshalValidate(parsed, &config) + err = serialization.MapUnmarshalValidate(parsed, &config) expect.NoError(t, err) expect.Equal(t, config.Format, FormatCombined) diff --git a/internal/logging/accesslog/file_logger.go b/internal/logging/accesslog/file_logger.go index 7a85662..36e791c 100644 --- a/internal/logging/accesslog/file_logger.go +++ b/internal/logging/accesslog/file_logger.go @@ -7,7 +7,7 @@ import ( "path/filepath" "sync" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/utils" ) @@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) { if opened, ok := openedFiles[path]; ok { opened.refCount.Add() return opened, nil - } else { - // cannot open as O_APPEND as we need Seek and WriteAt - f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644) - if err != nil { - return nil, fmt.Errorf("access log open error: %w", err) - } - if _, err := f.Seek(0, io.SeekEnd); err != nil { - return nil, fmt.Errorf("access log seek error: %w", err) - } - file = &File{f: f, path: path, refCount: utils.NewRefCounter()} - openedFiles[path] = file - go file.closeOnZero() } + // cannot open as O_APPEND as we need Seek and WriteAt + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + return nil, fmt.Errorf("access log open error: %w", err) + } + if _, err := f.Seek(0, io.SeekEnd); err != nil { + return nil, fmt.Errorf("access log seek error: %w", err) + } + file = &File{f: f, path: path, refCount: utils.NewRefCounter()} + openedFiles[path] = file + go file.closeOnZero() return file, nil } @@ -90,7 +89,7 @@ func (f *File) Close() error { } func (f *File) closeOnZero() { - defer logging.Debug(). + defer log.Debug(). Str("path", f.path). Msg("access log closed") diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 375d4dd..eae974e 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -1,4 +1,3 @@ -//nolint:zerologlint package logging import ( @@ -10,6 +9,8 @@ import ( "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/utils/strutils" + + zerologlog "github.com/rs/zerolog/log" ) var ( @@ -61,22 +62,6 @@ func InitLogger(out ...io.Writer) { log.SetOutput(writer) log.SetPrefix("") log.SetFlags(0) + zerolog.TimeFieldFormat = timeFmt + zerologlog.Logger = logger } - -func DiscardLogger() { zerolog.SetGlobalLevel(zerolog.Disabled) } - -func AddHook(h zerolog.Hook) { logger = logger.Hook(h) } - -func GetLogger() *zerolog.Logger { return &logger } -func With() zerolog.Context { return logger.With() } - -func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) } - -func Info() *zerolog.Event { return logger.Info() } -func Warn() *zerolog.Event { return logger.Warn() } -func Error() *zerolog.Event { return logger.Error() } -func Err(err error) *zerolog.Event { return logger.Err(err) } -func Debug() *zerolog.Event { return logger.Debug() } -func Fatal() *zerolog.Event { return logger.Fatal() } -func Panic() *zerolog.Event { return logger.Panic() } -func Trace() *zerolog.Event { return logger.Trace() } diff --git a/internal/maxmind/types/config.go b/internal/maxmind/types/config.go index 298f3a5..a2bb526 100644 --- a/internal/maxmind/types/config.go +++ b/internal/maxmind/types/config.go @@ -2,8 +2,8 @@ package maxmind import ( "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" ) type ( @@ -28,6 +28,6 @@ func (cfg *Config) Validate() gperr.Error { } func (cfg *Config) Logger() *zerolog.Logger { - l := logging.With().Str("database", string(cfg.Database)).Logger() + l := log.With().Str("database", string(cfg.Database)).Logger() return &l } diff --git a/internal/metrics/period/poller.go b/internal/metrics/period/poller.go index 9422f7a..b068c8f 100644 --- a/internal/metrics/period/poller.go +++ b/internal/metrics/period/poller.go @@ -10,8 +10,8 @@ import ( "sync" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/atomic" ) @@ -47,7 +47,7 @@ var initDataDirOnce sync.Once func initDataDir() { if err := os.MkdirAll(saveBaseDir, 0o755); err != nil { - logging.Error().Err(err).Msg("failed to create metrics data directory") + log.Error().Err(err).Msg("failed to create metrics data directory") } } @@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler]( } func (p *Poller[T, AggregateT]) savePath() string { - return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name)) + return filepath.Join(saveBaseDir, p.name+".json") } func (p *Poller[T, AggregateT]) load() error { @@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) { func (p *Poller[T, AggregateT]) Start() { t := task.RootTask("poller." + p.name) + l := log.With().Str("name", p.name).Logger() err := p.load() if err != nil { if !os.IsNotExist(err) { - logging.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name) + l.Err(err).Msg("failed to load last metrics data") } } else { - logging.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total()) + l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data") } go func() { @@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() { gatherErrsTicker.Stop() saveTicker.Stop() - p.save() + if err := p.save(); err != nil { + l.Err(err).Msg("failed to save metrics data") + } t.Finish(nil) }() - logging.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval) + l.Debug().Dur("interval", pollInterval).Msg("Starting poller") p.pollWithTimeout(t.Context()) @@ -176,7 +179,7 @@ func (p *Poller[T, AggregateT]) Start() { case <-gatherErrsTicker.C: errs, ok := p.gatherErrs() if ok { - logging.Error().Msg(errs) + log.Error().Msg(errs) } p.clearErrs() } diff --git a/internal/metrics/systeminfo/system_info.go b/internal/metrics/systeminfo/system_info.go index c76d165..05e317e 100644 --- a/internal/metrics/systeminfo/system_info.go +++ b/internal/metrics/systeminfo/system_info.go @@ -8,6 +8,7 @@ import ( "syscall" "time" + "github.com/rs/zerolog/log" "github.com/shirou/gopsutil/v4/cpu" "github.com/shirou/gopsutil/v4/disk" "github.com/shirou/gopsutil/v4/mem" @@ -16,7 +17,6 @@ import ( "github.com/shirou/gopsutil/v4/warning" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/metrics/period" ) @@ -130,7 +130,7 @@ func getSystemInfo(ctx context.Context, lastResult *SystemInfo) (*SystemInfo, er } }) if allWarnings.HasError() { - logging.Warn().Msg(allWarnings.String()) + log.Warn().Msg(allWarnings.String()) } if allErrors.HasError() { return nil, allErrors.Error() @@ -195,7 +195,7 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf if len(s.Disks) == 0 { return errs.Error() } - logging.Warn().Msg(errs.String()) + log.Warn().Msg(errs.String()) } return nil } diff --git a/internal/net/gphttp/body.go b/internal/net/gphttp/body.go index b2d0173..be0bf51 100644 --- a/internal/net/gphttp/body.go +++ b/internal/net/gphttp/body.go @@ -6,7 +6,7 @@ import ( "errors" "net/http" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" ) func WriteBody(w http.ResponseWriter, body []byte) { @@ -14,9 +14,9 @@ func WriteBody(w http.ResponseWriter, body []byte) { switch { case errors.Is(err, http.ErrHandlerTimeout), errors.Is(err, context.DeadlineExceeded): - logging.Err(err).Msg("timeout writing body") + log.Err(err).Msg("timeout writing body") default: - logging.Err(err).Msg("failed to write body") + log.Err(err).Msg("failed to write body") } } } diff --git a/internal/net/gphttp/gpwebsocket/utils.go b/internal/net/gphttp/gpwebsocket/utils.go index 0a0d229..c5bd3cc 100644 --- a/internal/net/gphttp/gpwebsocket/utils.go +++ b/internal/net/gphttp/gpwebsocket/utils.go @@ -9,11 +9,11 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" ) func warnNoMatchDomains() { - logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins") + log.Warn().Msg("no match domains configured, accepting websocket API request from all origins") } var warnNoMatchDomainOnce sync.Once diff --git a/internal/net/gphttp/loadbalancer/loadbalancer.go b/internal/net/gphttp/loadbalancer/loadbalancer.go index d36bf98..90b795c 100644 --- a/internal/net/gphttp/loadbalancer/loadbalancer.go +++ b/internal/net/gphttp/loadbalancer/loadbalancer.go @@ -7,8 +7,8 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types" "github.com/yusing/go-proxy/internal/task" @@ -47,7 +47,7 @@ func New(cfg *Config) *LoadBalancer { lb := &LoadBalancer{ Config: new(Config), pool: pool.New[Server]("loadbalancer." + cfg.Link), - l: logging.With().Str("name", cfg.Link).Logger(), + l: log.With().Str("name", cfg.Link).Logger(), } lb.UpdateConfigIfNeeded(cfg) return lb diff --git a/internal/net/gphttp/loadbalancer/types/weight.go b/internal/net/gphttp/loadbalancer/types/weight.go index 2bf7d84..2339a27 100644 --- a/internal/net/gphttp/loadbalancer/types/weight.go +++ b/internal/net/gphttp/loadbalancer/types/weight.go @@ -1,3 +1,3 @@ package types -type Weight uint16 +type Weight int diff --git a/internal/net/gphttp/logging.go b/internal/net/gphttp/logging.go index cfb67f0..9db2f35 100644 --- a/internal/net/gphttp/logging.go +++ b/internal/net/gphttp/logging.go @@ -4,14 +4,14 @@ import ( "net/http" "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" ) func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { - return logging.WithLevel(level). - Str("remote", r.RemoteAddr). - Str("host", r.Host). - Str("uri", r.Method+" "+r.RequestURI) + return log.WithLevel(level). //nolint:zerologlint + Str("remote", r.RemoteAddr). + Str("host", r.Host). + Str("uri", r.Method+" "+r.RequestURI) } func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) } diff --git a/internal/net/gphttp/middleware/captcha/middleware.go b/internal/net/gphttp/middleware/captcha/middleware.go index 91da59c..c4769ab 100644 --- a/internal/net/gphttp/middleware/captcha/middleware.go +++ b/internal/net/gphttp/middleware/captcha/middleware.go @@ -4,8 +4,8 @@ import ( "net/http" "text/template" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/auth" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp" _ "embed" @@ -55,7 +55,7 @@ func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed boo "FormHTML": p.FormHTML(), }) if err != nil { - logging.Error().Err(err).Msg("failed to execute captcha page") + log.Error().Err(err).Msg("failed to execute captcha page") } return false } diff --git a/internal/net/gphttp/middleware/cidr_whitelist.go b/internal/net/gphttp/middleware/cidr_whitelist.go index 6b9271f..291b45f 100644 --- a/internal/net/gphttp/middleware/cidr_whitelist.go +++ b/internal/net/gphttp/middleware/cidr_whitelist.go @@ -7,7 +7,7 @@ import ( "github.com/go-playground/validator/v10" gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/types" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -34,7 +34,7 @@ var ( ) func init() { - utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool { + serialization.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool { statusCode := fl.Field().Int() return gphttp.IsStatusCodeValid(int(statusCode)) }) @@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool { ipStr = r.RemoteAddr } ip := net.ParseIP(ipStr) - for _, cidr := range wl.CIDRWhitelistOpts.Allow { + for _, cidr := range wl.Allow { if cidr.Contains(ip) { wl.cachedAddr.Store(r.RemoteAddr, true) allow = true @@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool { } if !allow { wl.cachedAddr.Store(r.RemoteAddr, false) - wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow) + wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow) } } if !allow { diff --git a/internal/net/gphttp/middleware/cidr_whitelist_test.go b/internal/net/gphttp/middleware/cidr_whitelist_test.go index a8c7cee..0d73f29 100644 --- a/internal/net/gphttp/middleware/cidr_whitelist_test.go +++ b/internal/net/gphttp/middleware/cidr_whitelist_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ "message": testMessage, }) - ExpectError(t, utils.ErrValidationError, err) + ExpectError(t, serialization.ErrValidationError, err) }) t.Run("invalid cidr", func(t *testing.T) { _, err := CIDRWhiteList.New(OptionsRaw{ @@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) { "status_code": 600, "message": testMessage, }) - ExpectError(t, utils.ErrValidationError, err) + ExpectError(t, serialization.ErrValidationError, err) }) } diff --git a/internal/net/gphttp/middleware/cloudflare_real_ip.go b/internal/net/gphttp/middleware/cloudflare_real_ip.go index 19314d6..59c79d3 100644 --- a/internal/net/gphttp/middleware/cloudflare_real_ip.go +++ b/internal/net/gphttp/middleware/cloudflare_real_ip.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "errors" "fmt" "io" @@ -9,8 +10,8 @@ import ( "sync" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/atomic" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -89,21 +90,29 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { ) if err != nil { cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval)) - logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) + log.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval)) return nil } if len(cfCIDRs) == 0 { - logging.Warn().Msg("cloudflare CIDR range is empty") + log.Warn().Msg("cloudflare CIDR range is empty") } } cfCIDRsLastUpdate.Store(time.Now()) - logging.Info().Msg("cloudflare CIDR range updated") + log.Info().Msg("cloudflare CIDR range updated") return } func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error { - resp, err := http.Get(endpoint) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) //nolint:gosec if err != nil { return err } diff --git a/internal/net/gphttp/middleware/custom_error_page.go b/internal/net/gphttp/middleware/custom_error_page.go index 76110be..baef63f 100644 --- a/internal/net/gphttp/middleware/custom_error_page.go +++ b/internal/net/gphttp/middleware/custom_error_page.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage" @@ -32,7 +32,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error { if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) { errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode) if ok { - logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode) + log.Debug().Msgf("error page for status %d loaded", resp.StatusCode) _, _ = io.Copy(io.Discard, resp.Body) // drain the original body resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(errorPage)) @@ -40,7 +40,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error { resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage))) resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8") } else { - logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode) + log.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } return nil } @@ -56,7 +56,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo filename := path[len(StaticFilePathPrefix):] file, ok := errorpage.GetStaticFile(filename) if !ok { - logging.Error().Msg("unable to load resource " + filename) + log.Error().Msg("unable to load resource " + filename) return false } ext := filepath.Ext(filename) @@ -68,10 +68,10 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo case ".css": w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8") default: - logging.Error().Msgf("unexpected file type %q for %s", ext, filename) + log.Error().Msgf("unexpected file type %q for %s", ext, filename) } if _, err := w.Write(file); err != nil { - logging.Err(err).Msg("unable to write resource " + filename) + log.Err(err).Msg("unable to write resource " + filename) http.Error(w, "Error page failure", http.StatusInternalServerError) } return true diff --git a/internal/net/gphttp/middleware/errorpage/error_page.go b/internal/net/gphttp/middleware/errorpage/error_page.go index a1acec6..e63b65a 100644 --- a/internal/net/gphttp/middleware/errorpage/error_page.go +++ b/internal/net/gphttp/middleware/errorpage/error_page.go @@ -6,9 +6,9 @@ import ( "path" "sync" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -48,7 +48,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) { func loadContent() { files, err := U.ListFiles(errPagesBasePath, 0) if err != nil { - logging.Err(err).Msg("failed to list error page resources") + log.Err(err).Msg("failed to list error page resources") return } for _, file := range files { @@ -57,11 +57,11 @@ func loadContent() { } content, err := os.ReadFile(file) if err != nil { - logging.Warn().Err(err).Msgf("failed to read error page resource %s", file) + log.Warn().Err(err).Msgf("failed to read error page resource %s", file) continue } file = path.Base(file) - logging.Info().Msgf("error page resource %s loaded", file) + log.Info().Msgf("error page resource %s loaded", file) fileContentMap.Store(file, content) } } @@ -83,9 +83,9 @@ func watchDir() { loadContent() case events.ActionFileDeleted: fileContentMap.Delete(filename) - logging.Warn().Msgf("error page resource %s deleted", filename) + log.Warn().Msgf("error page resource %s deleted", filename) case events.ActionFileRenamed: - logging.Warn().Msgf("error page resource %s deleted", filename) + log.Warn().Msgf("error page resource %s deleted", filename) fileContentMap.Delete(filename) loadContent() } diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index 213aa10..ab7227f 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -8,11 +8,11 @@ import ( "sort" "strings" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" gphttp "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( @@ -87,7 +87,7 @@ func NewMiddleware[ImplType any]() *Middleware { func (m *Middleware) enableTrace() { if tracer, ok := m.impl.(MiddlewareWithTracer); ok { tracer.enableTrace() - logging.Trace().Msgf("middleware %s enabled trace", m.name) + log.Trace().Msgf("middleware %s enabled trace", m.name) } } @@ -118,14 +118,14 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error { "priority": optsRaw["priority"], "bypass": optsRaw["bypass"], } - if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil { + if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil { return err } optsRaw = maps.Clone(optsRaw) for k := range commonOpts { delete(optsRaw, k) } - return utils.MapUnmarshalValidate(optsRaw, m.impl) + return serialization.MapUnmarshalValidate(optsRaw, m.impl) } func (m *Middleware) finalize() error { diff --git a/internal/net/gphttp/middleware/middlewares.go b/internal/net/gphttp/middleware/middlewares.go index ca2a85f..8ae4ee9 100644 --- a/internal/net/gphttp/middleware/middlewares.go +++ b/internal/net/gphttp/middleware/middlewares.go @@ -3,9 +3,9 @@ package middleware import ( "path" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -59,7 +59,7 @@ func LoadComposeFiles() { errs := gperr.NewBuilder("middleware compile errors") middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0) if err != nil { - logging.Err(err).Msg("failed to list middleware definitions") + log.Err(err).Msg("failed to list middleware definitions") return } for _, defFile := range middlewareDefs { @@ -75,7 +75,7 @@ func LoadComposeFiles() { continue } allMiddlewares[name] = m - logging.Info(). + log.Info(). Str("src", path.Base(defFile)). Str("name", name). Msg("middleware loaded") @@ -94,7 +94,7 @@ func LoadComposeFiles() { continue } allMiddlewares[name] = m - logging.Info(). + log.Info(). Str("src", path.Base(defFile)). Str("name", name). Msg("middleware loaded") diff --git a/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go b/internal/net/gphttp/reverseproxy/reverse_proxy.go similarity index 96% rename from internal/net/gphttp/reverseproxy/reverse_proxy_mod.go rename to internal/net/gphttp/reverseproxy/reverse_proxy.go index a28d871..59faafe 100644 --- a/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/gphttp/reverseproxy/reverse_proxy.go @@ -25,7 +25,7 @@ import ( "sync" "github.com/rs/zerolog" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/logging/accesslog" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/types" @@ -138,7 +138,7 @@ func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper panic("nil transport") } rp := &ReverseProxy{ - Logger: logging.With().Str("name", name).Logger(), + Logger: log.With().Str("name", name).Logger(), Transport: transport, TargetName: name, TargetURL: target, @@ -173,17 +173,17 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err case errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, context.DeadlineExceeded): - logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error") + log.Debug().Err(err).Str("url", reqURL).Msg("http proxy error") default: var recordErr tls.RecordHeaderError if errors.As(err, &recordErr) { - logging.Error(). + log.Error(). Str("url", reqURL). Msgf(`scheme was likely misconfigured as https, try setting "proxy.%s.scheme" back to "http"`, p.TargetName) - logging.Err(err).Msg("underlying error") + log.Err(err).Msg("underlying error") } else { - logging.Err(err).Str("url", reqURL).Msg("http proxy error") + log.Err(err).Str("url", reqURL).Msg("http proxy error") } } @@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { transport := p.Transport ctx := req.Context() - /* trunk-ignore(golangci-lint/revive) */ if ctx.Done() != nil { // CloseNotifier predates context.Context, and has been // entirely superseded by it. If the request contains @@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { return nil }, } - outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck res, err := transport.RoundTrip(outreq) @@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R res.Header = rw.Header() res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above if err := res.Write(brw); err != nil { - /* trunk-ignore(golangci-lint/errorlint) */ + //nolint:errorlint p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true) return } if err := brw.Flush(); err != nil { - /* trunk-ignore(golangci-lint/errorlint) */ + //nolint:errorlint p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true) return } bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn) - /* trunk-ignore(golangci-lint/errcheck) */ + //nolint:errcheck bdp.Start() } diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index c42afec..bf2e7f6 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -9,14 +9,14 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/acl" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" ) type CertProvider interface { - GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) + GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) } type Server struct { @@ -53,7 +53,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) { func NewServer(opt Options) (s *Server) { var httpSer, httpsSer *http.Server - logger := logging.With().Str("server", opt.Name).Logger() + logger := log.With().Str("server", opt.Name).Logger() certAvailable := false if opt.CertProvider != nil { diff --git a/internal/notif/config.go b/internal/notif/config.go index 8bc396e..7f91c3d 100644 --- a/internal/notif/config.go +++ b/internal/notif/config.go @@ -2,7 +2,7 @@ package notif import ( "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type NotificationConfig struct { @@ -46,5 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error) Withf("expect %s or %s", ProviderWebhook, ProviderGotify) } - return utils.MapUnmarshalValidate(m, cfg.Provider) + return serialization.MapUnmarshalValidate(m, cfg.Provider) } diff --git a/internal/notif/config_test.go b/internal/notif/config_test.go index a7afaac..68d3897 100644 --- a/internal/notif/config_test.go +++ b/internal/notif/config_test.go @@ -4,7 +4,7 @@ import ( "net/http" "testing" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -181,7 +181,7 @@ func TestNotificationConfig(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var cfg NotificationConfig provider := tt.cfg["provider"] - err := utils.MapUnmarshalValidate(tt.cfg, &cfg) + err := serialization.MapUnmarshalValidate(tt.cfg, &cfg) if tt.wantErr { ExpectHasError(t, err) } else { diff --git a/internal/notif/providers.go b/internal/notif/providers.go index 84c3c41..be7ba9f 100644 --- a/internal/notif/providers.go +++ b/internal/notif/providers.go @@ -8,14 +8,14 @@ import ( "net/http" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" ) type ( Provider interface { - utils.CustomValidator + serialization.CustomValidator GetName() string GetURL() string @@ -73,7 +73,7 @@ func (msg *LogMessage) notify(ctx context.Context, provider Provider) error { switch resp.StatusCode { case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent: body, _ := io.ReadAll(resp.Body) - logging.Debug(). + log.Debug(). Str("provider", provider.GetName()). Str("url", provider.GetURL()). Str("status", resp.Status). diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go index 36c1811..340b628 100755 --- a/internal/route/provider/docker.go +++ b/internal/route/provider/docker.go @@ -7,12 +7,12 @@ import ( "github.com/docker/docker/client" "github.com/goccy/go-yaml" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/route" - U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/watcher" ) @@ -36,7 +36,7 @@ func DockerProviderImpl(name, dockerHost string) ProviderImpl { return &DockerProvider{ name, dockerHost, - logging.With().Str("type", "docker").Str("name", name).Logger(), + log.With().Str("type", "docker").Str("name", name).Logger(), } } @@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) { // Always non-nil. func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) { if !container.IsExplicit && p.IsExplicitOnly() { - return nil, nil + return make(route.Routes, 0), nil } routes := make(route.Routes, len(container.Aliases)) @@ -180,7 +180,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) } // deserialize map into entry object - err := U.MapUnmarshalValidate(entryMap, r) + err := serialization.MapUnmarshalValidate(entryMap, r) if err != nil { errs.Add(err.Subject(alias)) } else { @@ -189,7 +189,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) } if wildcardProps != nil { for _, re := range routes { - if err := U.MapUnmarshalValidate(wildcardProps, re); err != nil { + if err := serialization.MapUnmarshalValidate(wildcardProps, re); err != nil { errs.Add(err.Subject(docker.WildcardAlias)) break } diff --git a/internal/route/provider/file.go b/internal/route/provider/file.go index 6c23a60..06921b7 100644 --- a/internal/route/provider/file.go +++ b/internal/route/provider/file.go @@ -6,11 +6,11 @@ import ( "strings" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" W "github.com/yusing/go-proxy/internal/watcher" ) @@ -24,7 +24,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) { impl := &FileProvider{ fileName: filename, path: path.Join(common.ConfigBasePath, filename), - l: logging.With().Str("type", "file").Str("name", filename).Logger(), + l: log.With().Str("type", "file").Str("name", filename).Logger(), } _, err := os.Stat(impl.path) if err != nil { @@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) { } func validate(data []byte) (routes route.Routes, err gperr.Error) { - err = utils.UnmarshalValidateYAML(data, &routes) + err = serialization.UnmarshalValidateYAML(data, &routes) return } diff --git a/internal/route/route.go b/internal/route/route.go index 9ff3eb6..62c41f5 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -7,12 +7,12 @@ import ( "time" "github.com/docker/docker/api/types/container" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/homepage" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" - "github.com/yusing/go-proxy/internal/logging" netutils "github.com/yusing/go-proxy/internal/net" net "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/proxmox" @@ -116,7 +116,7 @@ func (r *Route) Validate() gperr.Error { Subject(containerName) } - l := logging.With().Str("container", containerName).Logger() + l := log.With().Str("container", containerName).Logger() l.Info().Msg("checking if container is running") running, err := node.LXCIsRunning(ctx, vmid) diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go index aa8e8d1..0601e3a 100644 --- a/internal/route/rules/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -3,7 +3,7 @@ package rules import ( "testing" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) { var rules struct { Rules Rules } - err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules) + err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules) ExpectNoError(t, err) ExpectEqual(t, len(rules.Rules), len(test)) ExpectEqual(t, rules.Rules[0].Name, "test") diff --git a/internal/route/stream.go b/internal/route/stream.go index bd87cfe..9c76de0 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -5,9 +5,9 @@ import ( "errors" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/idlewatcher" - "github.com/yusing/go-proxy/internal/logging" net "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" @@ -32,7 +32,7 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) { // TODO: support non-coherent scheme return &StreamRoute{ Route: base, - l: logging.With(). + l: log.With(). Str("type", string(base.Scheme)). Str("name", base.Name()). Logger(), diff --git a/internal/route/types/http_config_test.go b/internal/route/types/http_config_test.go index c6ea8cf..8f41760 100644 --- a/internal/route/types/http_config_test.go +++ b/internal/route/types/http_config_test.go @@ -6,7 +6,7 @@ import ( . "github.com/yusing/go-proxy/internal/route" route "github.com/yusing/go-proxy/internal/route/types" - "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/serialization" expect "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfg := Route{} tt.input["host"] = "internal" - err := utils.MapUnmarshalValidate(tt.input, &cfg) + err := serialization.MapUnmarshalValidate(tt.input, &cfg) if err != nil { expect.NoError(t, err) } diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go index 62149af..581e273 100644 --- a/internal/route/udp_forwarder.go +++ b/internal/route/udp_forwarder.go @@ -6,8 +6,8 @@ import ( "net" "sync" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/types" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -88,7 +88,7 @@ func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) { func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) { buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob) if err == nil { - logging.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) + log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) } return } @@ -102,7 +102,7 @@ func (conn *UDPConn) read() (err error) { conn.buf.oobn = 0 } if err == nil { - logging.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) + log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) } return } @@ -110,7 +110,7 @@ func (conn *UDPConn) read() (err error) { func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) { buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr) if err == nil { - logging.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) + log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) } return } @@ -120,12 +120,12 @@ func (conn *UDPConn) write() (err error) { case *net.UDPConn: conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil) if err == nil { - logging.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) + log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) } default: _, err = dstConn.Write(conn.buf.data[:conn.buf.n]) if err == nil { - logging.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n) + log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n) } } diff --git a/internal/utils/serialization.go b/internal/serialization/serialization.go similarity index 97% rename from internal/utils/serialization.go rename to internal/serialization/serialization.go index e540d23..54b9631 100644 --- a/internal/utils/serialization.go +++ b/internal/serialization/serialization.go @@ -1,4 +1,4 @@ -package utils +package serialization import ( "encoding/json" @@ -12,7 +12,9 @@ import ( "github.com/go-playground/validator/v10" "github.com/goccy/go-yaml" + "github.com/puzpuzpuz/xsync/v4" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -40,14 +42,14 @@ var ( var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]() -var defaultValues = functional.NewMapOf[reflect.Type, func() any]() +var defaultValues = xsync.NewMapOf[reflect.Type, func() any]() func RegisterDefaultValueFactory[T any](factory func() *T) { t := reflect.TypeFor[T]() if t.Kind() == reflect.Ptr { panic("pointer of pointer") } - if defaultValues.Has(t) { + if _, ok := defaultValues.Load(t); ok { panic("default value for " + t.String() + " already registered") } defaultValues.Store(t, func() any { return factory() }) @@ -259,7 +261,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool) errs.Add(err.Subject(k)) } } else { - errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping)))) + errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping)))) } } if hasValidateTag && checkValidateTag { diff --git a/internal/utils/serialization_test.go b/internal/serialization/serialization_test.go similarity index 99% rename from internal/utils/serialization_test.go rename to internal/serialization/serialization_test.go index 6a08779..d067f81 100644 --- a/internal/utils/serialization_test.go +++ b/internal/serialization/serialization_test.go @@ -1,4 +1,4 @@ -package utils +package serialization import ( "reflect" diff --git a/internal/utils/validation.go b/internal/serialization/validation.go similarity index 95% rename from internal/utils/validation.go rename to internal/serialization/validation.go index 490c348..0b222f2 100644 --- a/internal/utils/validation.go +++ b/internal/serialization/validation.go @@ -1,4 +1,4 @@ -package utils +package serialization import ( "github.com/go-playground/validator/v10" diff --git a/internal/task/task.go b/internal/task/task.go index 2421a13..9faa7a6 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -7,9 +7,9 @@ import ( "sync/atomic" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -116,14 +116,14 @@ func (t *Task) Finish(reason any) { func (t *Task) finish(reason any) { t.cancel(fmtCause(reason)) if !waitWithTimeout(t.childrenDone) { - logging.Debug(). + log.Debug(). Str("task", t.name). Strs("subtasks", t.listChildren()). Msg("Timeout waiting for subtasks to finish") } go t.runCallbacks() if !waitWithTimeout(t.callbacksDone) { - logging.Debug(). + log.Debug(). Str("task", t.name). Strs("callbacks", t.listCallbacks()). Msg("Timeout waiting for callbacks to finish") @@ -134,7 +134,7 @@ func (t *Task) finish(reason any) { } t.parent.subChildCount() allTasks.Remove(t) - logging.Trace().Msg("task " + t.name + " finished") + log.Trace().Msg("task " + t.name + " finished") } // Subtask returns a new subtask with the given name, derived from the parent's context. @@ -166,7 +166,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task { }() } - logging.Trace().Msg("task " + child.name + " started") + log.Trace().Msg("task " + child.name + " started") return child } @@ -189,7 +189,7 @@ func (t *Task) MarshalText() ([]byte, error) { func (t *Task) invokeWithRecover(fn func(), caller string) { defer func() { if err := recover(); err != nil { - logging.Error(). + log.Error(). Interface("err", err). Msg("panic in task " + t.name + "." + caller) if common.IsDebug { diff --git a/internal/task/utils.go b/internal/task/utils.go index da0bc36..e75b3be 100644 --- a/internal/task/utils.go +++ b/internal/task/utils.go @@ -9,7 +9,7 @@ import ( "syscall" "time" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" F "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -68,10 +68,10 @@ func GracefulShutdown(timeout time.Duration) (err error) { case <-after: b, err := json.Marshal(DebugTaskList()) if err != nil { - logging.Warn().Err(err).Msg("failed to marshal tasks") + log.Warn().Err(err).Msg("failed to marshal tasks") return context.DeadlineExceeded } - logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size()) + log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size()) return context.DeadlineExceeded } } @@ -87,6 +87,6 @@ func WaitExit(shutdownTimeout int) { <-sig // gracefully shutdown - logging.Info().Msg("shutting down") + log.Info().Msg("shutting down") _ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout)) } diff --git a/internal/utils/go.mod b/internal/utils/go.mod new file mode 100644 index 0000000..6ccf74c --- /dev/null +++ b/internal/utils/go.mod @@ -0,0 +1,21 @@ +module github.com/yusing/go-proxy/internal/utils + +go 1.24.3 + +require ( + github.com/goccy/go-yaml v1.17.1 + github.com/puzpuzpuz/xsync/v4 v4.1.0 + github.com/rs/zerolog v1.34.0 + github.com/stretchr/testify v1.10.0 + go.uber.org/atomic v1.11.0 + golang.org/x/text v0.25.0 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/sys v0.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/internal/utils/go.sum b/internal/utils/go.sum new file mode 100644 index 0000000..7477df6 --- /dev/null +++ b/internal/utils/go.sum @@ -0,0 +1,36 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= +github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U= +github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/utils/io.go b/internal/utils/io.go index 5d8488a..33881fc 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -8,7 +8,6 @@ import ( "sync" "syscall" - "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/utils/synk" ) @@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re } } -func (p BidirectionalPipe) Start() gperr.Error { +func (p BidirectionalPipe) Start() error { var wg sync.WaitGroup wg.Add(2) - b := gperr.NewBuilder("bidirectional pipe error") + var srcErr, dstErr error go func() { - b.Add(p.pSrcDst.Start()) + srcErr = p.pSrcDst.Start() wg.Done() }() go func() { - b.Add(p.pDstSrc.Start()) + dstErr = p.pDstSrc.Start() wg.Done() }() wg.Wait() - return b.Error() + return errors.Join(srcErr, dstErr) } type httpFlusher interface { @@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { wCloser, wCanClose := dst.Writer.(io.Closer) rCloser, rCanClose := src.Reader.(io.Closer) if wCanClose || rCanClose { - if src.ctx == dst.ctx { - go func() { - <-src.ctx.Done() - if wCanClose { - wCloser.Close() - } - if rCanClose { - rCloser.Close() - } - }() - } else { - if wCloser != nil { - go func() { - <-src.ctx.Done() - wCloser.Close() - }() + go func() { + select { + case <-src.ctx.Done(): + case <-dst.ctx.Done(): } - if rCloser != nil { - go func() { - <-dst.ctx.Done() - rCloser.Close() - }() + if rCanClose { + defer rCloser.Close() } - } + if wCanClose { + defer wCloser.Close() + } + }() } flusher := getHTTPFlusher(dst.Writer) canFlush := flusher != nil diff --git a/internal/utils/pool/pool.go b/internal/utils/pool/pool.go index 7abe4bc..84056e9 100644 --- a/internal/utils/pool/pool.go +++ b/internal/utils/pool/pool.go @@ -4,7 +4,7 @@ import ( "sort" "github.com/puzpuzpuz/xsync/v4" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" ) type ( @@ -29,12 +29,12 @@ func (p Pool[T]) Name() string { func (p Pool[T]) Add(obj T) { p.checkExists(obj.Key()) p.m.Store(obj.Key(), obj) - logging.Info().Msgf("%s: added %s", p.name, obj.Name()) + log.Info().Msgf("%s: added %s", p.name, obj.Name()) } func (p Pool[T]) Del(obj T) { p.m.Delete(obj.Key()) - logging.Info().Msgf("%s: removed %s", p.name, obj.Name()) + log.Info().Msgf("%s: removed %s", p.name, obj.Name()) } func (p Pool[T]) Get(key string) (T, bool) { diff --git a/internal/utils/pool/pool_debug.go b/internal/utils/pool/pool_debug.go index 6ab84f6..bb444db 100644 --- a/internal/utils/pool/pool_debug.go +++ b/internal/utils/pool/pool_debug.go @@ -5,11 +5,11 @@ package pool import ( "runtime/debug" - "github.com/yusing/go-proxy/internal/logging" + "github.com/rs/zerolog/log" ) func (p Pool[T]) checkExists(key string) { if _, ok := p.m.Load(key); ok { - logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack())) + log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack())) } } diff --git a/internal/utils/synk/pool.go b/internal/utils/synk/pool.go index 11e282a..15cbd64 100644 --- a/internal/utils/synk/pool.go +++ b/internal/utils/synk/pool.go @@ -1,16 +1,35 @@ package synk import ( - "os" - "os/signal" - "sync/atomic" - "time" - - "github.com/yusing/go-proxy/internal/logging" + "runtime" + "unsafe" ) +type weakBuf = unsafe.Pointer + +func makeWeak(b *[]byte) weakBuf { + ptr := runtime_registerWeakPointer(unsafe.Pointer(b)) + runtime.KeepAlive(ptr) + addCleanup(b, addGCed, cap(*b)) + return weakBuf(ptr) +} + +func getBufFromWeak(w weakBuf) []byte { + ptr := (*[]byte)(runtime_makeStrongFromWeak(w)) + if ptr == nil { + return nil + } + return *ptr +} + +//go:linkname runtime_registerWeakPointer weak.runtime_registerWeakPointer +func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer + +//go:linkname runtime_makeStrongFromWeak weak.runtime_makeStrongFromWeak +func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer + type BytesPool struct { - pool chan []byte + pool chan weakBuf initSize int } @@ -22,19 +41,15 @@ const ( const ( InPoolLimit = 32 * mb - DefaultInitBytes = 32 * kb - PoolThreshold = 64 * kb + DefaultInitBytes = 4 * kb + PoolThreshold = 256 * kb DropThresholdHigh = 4 * mb PoolSize = InPoolLimit / PoolThreshold - - CleanupInterval = 5 * time.Second - MaxDropsPerCycle = 10 - MaxChecksPerCycle = 100 ) var bytesPool = &BytesPool{ - pool: make(chan []byte, PoolSize), + pool: make(chan weakBuf, PoolSize), initSize: DefaultInitBytes, } @@ -43,12 +58,18 @@ func NewBytesPool() *BytesPool { } func (p *BytesPool) Get() []byte { - select { - case b := <-p.pool: - subInPoolSize(int64(cap(b))) - return b - default: - return make([]byte, 0, p.initSize) + for { + select { + case bWeak := <-p.pool: + bPtr := getBufFromWeak(bWeak) + if bPtr == nil { + continue + } + addReused(cap(bPtr)) + return bPtr + default: + return make([]byte, 0, p.initSize) + } } } @@ -56,90 +77,43 @@ func (p *BytesPool) GetSized(size int) []byte { if size <= PoolThreshold { return make([]byte, size) } - select { - case b := <-p.pool: - if size <= cap(b) { - subInPoolSize(int64(cap(b))) - return b[:size] - } + for { select { - case p.pool <- b: - addInPoolSize(int64(cap(b))) + case bWeak := <-p.pool: + bPtr := getBufFromWeak(bWeak) + if bPtr == nil { + continue + } + capB := cap(bPtr) + if capB >= size { + addReused(capB) + return (bPtr)[:size] + } + select { + case p.pool <- bWeak: + default: + // just drop it + } default: } - default: + return make([]byte, size) } - return make([]byte, size) } func (p *BytesPool) Put(b []byte) { size := cap(b) - if size > DropThresholdHigh || poolFull() { + if size <= PoolThreshold || size > DropThresholdHigh { return } b = b[:0] + w := makeWeak(&b) select { - case p.pool <- b: - addInPoolSize(int64(size)) - return + case p.pool <- w: default: // just drop it } } -var inPoolSize int64 - -func addInPoolSize(size int64) { - atomic.AddInt64(&inPoolSize, size) -} - -func subInPoolSize(size int64) { - atomic.AddInt64(&inPoolSize, -size) -} - func init() { - // Periodically drop some buffers to prevent excessive memory usage - go func() { - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt) - - cleanupTicker := time.NewTicker(CleanupInterval) - defer cleanupTicker.Stop() - - for { - select { - case <-cleanupTicker.C: - dropBuffers() - case <-sigCh: - return - } - } - }() -} - -func poolFull() bool { - return atomic.LoadInt64(&inPoolSize) >= InPoolLimit -} - -// dropBuffers removes excess buffers from the pool when it grows too large. -func dropBuffers() { - // Check if pool has more than a threshold of buffers - count := 0 - droppedSize := 0 - checks := 0 - for count < MaxDropsPerCycle && checks < MaxChecksPerCycle && atomic.LoadInt64(&inPoolSize) > InPoolLimit*2/3 { - select { - case b := <-bytesPool.pool: - n := cap(b) - subInPoolSize(int64(n)) - droppedSize += n - count++ - default: - time.Sleep(10 * time.Millisecond) - } - checks++ - } - if count > 0 { - logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool") - } + initPoolStats() } diff --git a/internal/utils/synk/pool_bench_test.go b/internal/utils/synk/pool_bench_test.go index 6bb6e02..7339a23 100644 --- a/internal/utils/synk/pool_bench_test.go +++ b/internal/utils/synk/pool_bench_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024} +var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024} func BenchmarkBytesPool_GetSmall(b *testing.B) { for b.Loop() { diff --git a/internal/utils/synk/pool_debug.go b/internal/utils/synk/pool_debug.go new file mode 100644 index 0000000..0a7aa00 --- /dev/null +++ b/internal/utils/synk/pool_debug.go @@ -0,0 +1,55 @@ +//go:build !production + +package synk + +import ( + "os" + "os/signal" + "runtime" + "sync/atomic" + "time" + + "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +var ( + numReused, sizeReused uint64 + numGCed, sizeGCed uint64 +) + +func addReused(size int) { + atomic.AddUint64(&numReused, 1) + atomic.AddUint64(&sizeReused, uint64(size)) +} + +func addGCed(size int) { + atomic.AddUint64(&numGCed, 1) + atomic.AddUint64(&sizeGCed, uint64(size)) +} + +var addCleanup = runtime.AddCleanup[[]byte, int] + +func initPoolStats() { + go func() { + statsTicker := time.NewTicker(5 * time.Second) + defer statsTicker.Stop() + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + + for { + select { + case <-sig: + return + case <-statsTicker.C: + log.Info(). + Uint64("numReused", atomic.LoadUint64(&numReused)). + Str("sizeReused", strutils.FormatByteSize(atomic.LoadUint64(&sizeReused))). + Uint64("numGCed", atomic.LoadUint64(&numGCed)). + Str("sizeGCed", strutils.FormatByteSize(atomic.LoadUint64(&sizeGCed))). + Msg("bytes pool stats") + } + } + }() +} diff --git a/internal/utils/synk/pool_prod.go b/internal/utils/synk/pool_prod.go new file mode 100644 index 0000000..3cca12d --- /dev/null +++ b/internal/utils/synk/pool_prod.go @@ -0,0 +1,8 @@ +//go:build production + +package synk + +func addReused(size int) {} +func addGCed(size int) {} +func initPoolStats() {} +func addCleanup(ptr *[]byte, cleanup func(int), arg int) {} diff --git a/internal/utils/testing/expect.go b/internal/utils/testing/expect.go index a87ddaa..4455d7c 100644 --- a/internal/utils/testing/expect.go +++ b/internal/utils/testing/expect.go @@ -2,14 +2,16 @@ package expect import ( "os" + "strings" "testing" "github.com/stretchr/testify/require" - "github.com/yusing/go-proxy/internal/common" ) +var isTest = strings.HasSuffix(os.Args[0], ".test") + func init() { - if common.IsTest { + if isTest { // force verbose output os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...) } diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index 52f118c..13411c3 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -1,19 +1,11 @@ package expect import ( - "os" "testing" "github.com/stretchr/testify/require" - "github.com/yusing/go-proxy/internal/common" ) -func init() { - if common.IsTest { - os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...) - } -} - func ExpectNoError(t *testing.T, err error) { t.Helper() require.NoError(t, err) diff --git a/internal/utils/time_now.go b/internal/utils/time_now.go index 8b5e155..00bcff3 100644 --- a/internal/utils/time_now.go +++ b/internal/utils/time_now.go @@ -3,7 +3,6 @@ package utils import ( "time" - "github.com/yusing/go-proxy/internal/task" "go.uber.org/atomic" ) @@ -38,8 +37,6 @@ func init() { go func() { for { select { - case <-task.RootContext().Done(): - return case <-timeNowTicker.C: shouldCallTimeNow.Store(true) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index 45cca6a..4bfee33 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -8,8 +8,8 @@ import ( "github.com/fsnotify/fsnotify" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/events" ) @@ -41,13 +41,13 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { - logging.Panic().Err(err).Msg("unable to create fs watcher") + log.Panic().Err(err).Msg("unable to create fs watcher") } if err = w.Add(dirPath); err != nil { - logging.Panic().Err(err).Msg("unable to create fs watcher") + log.Panic().Err(err).Msg("unable to create fs watcher") } helper := &DirWatcher{ - Logger: logging.With(). + Logger: log.With(). Str("type", "dir"). Str("path", dirPath). Logger(), diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go index 09713c6..1a01f9e 100644 --- a/internal/watcher/docker_watcher.go +++ b/internal/watcher/docker_watcher.go @@ -8,9 +8,9 @@ import ( docker_events "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/watcher/events" ) @@ -81,7 +81,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList }() cEventCh, cErrCh := client.Events(ctx, options) - defer logging.Debug().Str("host", client.Address()).Msg("docker watcher closed") + defer log.Debug().Str("host", client.Address()).Msg("docker watcher closed") for { select { case <-ctx.Done(): @@ -153,7 +153,7 @@ func checkConnection(ctx context.Context, client *docker.SharedClient) bool { defer cancel() err := client.CheckConnection(ctx) if err != nil { - logging.Debug().Err(err).Msg("docker watcher: connection failed") + log.Debug().Err(err).Msg("docker watcher: connection failed") return false } return true diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go index a6b029f..9e0ead7 100644 --- a/internal/watcher/health/monitor/monitor.go +++ b/internal/watcher/health/monitor/monitor.go @@ -7,9 +7,9 @@ import ( "net/url" "time" + "github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/task" @@ -48,7 +48,7 @@ func NewMonitor(r routes.Route) health.HealthMonCheck { case routes.StreamRoute: mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig()) default: - logging.Panic().Msgf("unexpected route type: %T", r) + log.Panic().Msgf("unexpected route type: %T", r) } } if r.IsDocker() { @@ -91,7 +91,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error { mon.task = parent.Subtask("health_monitor", true) go func() { - logger := logging.With().Str("name", mon.service).Logger() + logger := log.With().Str("name", mon.service).Logger() defer func() { if mon.status.Load() != health.StatusError { @@ -221,7 +221,7 @@ func (mon *monitor) MarshalJSON() ([]byte, error) { } func (mon *monitor) checkUpdateHealth() error { - logger := logging.With().Str("name", mon.Name()).Logger() + logger := log.With().Str("name", mon.Name()).Logger() result, err := mon.checkHealth() var lastStatus health.Status diff --git a/pkg/version.go b/pkg/version.go index 03c42ec..c037f0e 100644 --- a/pkg/version.go +++ b/pkg/version.go @@ -18,7 +18,7 @@ func GetLastVersion() Version { func GetVersionHTTPHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(GetVersion().String())) + fmt.Fprint(w, GetVersion().String()) } } @@ -34,16 +34,16 @@ func init() { // lastVersion = ParseVersion(lastVersionStr) // } // if err != nil && !os.IsNotExist(err) { - // logging.Warn().Err(err).Msg("failed to read version file") + // log.Warn().Err(err).Msg("failed to read version file") // return // } // if err := f.Truncate(0); err != nil { - // logging.Warn().Err(err).Msg("failed to truncate version file") + // log.Warn().Err(err).Msg("failed to truncate version file") // return // } // _, err = f.WriteString(version) // if err != nil { - // logging.Warn().Err(err).Msg("failed to save version file") + // log.Warn().Err(err).Msg("failed to save version file") // return // } } diff --git a/socket-proxy/go.mod b/socket-proxy/go.mod index cd8a7de..5c83f0b 100644 --- a/socket-proxy/go.mod +++ b/socket-proxy/go.mod @@ -2,4 +2,19 @@ module github.com/yusing/go-proxy/socketproxy go 1.24.3 -require github.com/gorilla/mux v1.8.1 +replace github.com/yusing/go-proxy/internal/utils => ../internal/utils + +require ( + github.com/gorilla/mux v1.8.1 + github.com/yusing/go-proxy/internal/utils v0.0.0-00010101000000-000000000000 + golang.org/x/net v0.40.0 +) + +require ( + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/rs/zerolog v1.34.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect +) diff --git a/socket-proxy/go.sum b/socket-proxy/go.sum index 7128337..7967116 100644 --- a/socket-proxy/go.sum +++ b/socket-proxy/go.sum @@ -1,2 +1,34 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/socket-proxy/pkg/handler.go b/socket-proxy/pkg/handler.go index 51c2049..930bc9e 100644 --- a/socket-proxy/pkg/handler.go +++ b/socket-proxy/pkg/handler.go @@ -4,30 +4,33 @@ import ( "context" "net" "net/http" - "net/http/httputil" "strings" "time" "github.com/gorilla/mux" + "github.com/yusing/go-proxy/socketproxy/pkg/reverseproxy" ) var dialer = &net.Dialer{KeepAlive: 1 * time.Second} -func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "unix", DockerSocket) +func dialDockerSocket(socket string) func(ctx context.Context, _, _ string) (net.Conn, error) { + return func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", socket) + } } var DockerSocketHandler = dockerSocketHandler -func dockerSocketHandler() http.HandlerFunc { - rp := &httputil.ReverseProxy{ +func dockerSocketHandler(socket string) http.HandlerFunc { + rp := &reverseproxy.ReverseProxy{ Director: func(req *http.Request) { req.URL.Scheme = "http" req.URL.Host = "api.moby.localhost" req.RequestURI = req.URL.String() }, Transport: &http.Transport{ - DialContext: dialDockerSocket, + DialContext: dialDockerSocket(socket), + DisableCompression: true, }, } @@ -41,7 +44,7 @@ func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) { // ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg func NewHandler() http.Handler { r := mux.NewRouter() - socketHandler := DockerSocketHandler() + socketHandler := DockerSocketHandler(DockerSocket) const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}` const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}" diff --git a/socket-proxy/pkg/handler_test.go b/socket-proxy/pkg/handler_test.go index 59748e4..4c4cc0a 100644 --- a/socket-proxy/pkg/handler_test.go +++ b/socket-proxy/pkg/handler_test.go @@ -9,7 +9,7 @@ import ( . "github.com/yusing/go-proxy/socketproxy/pkg" ) -func mockDockerSocketHandler() http.HandlerFunc { +func mockDockerSocketHandler(_ string) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("mock docker response")) diff --git a/socket-proxy/pkg/reverseproxy/reverse_proxy.go b/socket-proxy/pkg/reverseproxy/reverse_proxy.go new file mode 100644 index 0000000..b2abaff --- /dev/null +++ b/socket-proxy/pkg/reverseproxy/reverse_proxy.go @@ -0,0 +1,367 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// License URL: https://cs.opensource.google/go/go/+/master:LICENSE + +// HTTP reverse proxy handler + +package reverseproxy + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptrace" + "net/textproto" + "strings" + "sync" + + "github.com/yusing/go-proxy/internal/utils" + "golang.org/x/net/http/httpguts" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. +type ReverseProxy struct { + // Director is a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + // + // By default, the X-Forwarded-For header is set to the + // value of the client IP address. If an X-Forwarded-For + // header already exists, the client IP is appended to the + // existing values. As a special case, if the header + // exists in the Request.Header map but has a nil value + // (such as when set by the Director func), the X-Forwarded-For + // header is not modified. + // + // To prevent IP spoofing, be sure to delete any pre-existing + // X-Forwarded-For header coming from the client or + // an untrusted proxy. + // + // Hop-by-hop headers are removed from the request after + // Director returns, which can remove headers added by + // Director. Use a Rewrite function instead to ensure + // modifications to the request are preserved. + // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // + // At most one of Rewrite or Director may be set. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + ctx := req.Context() + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + p.Director(outreq) + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + if !IsPrint(reqUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + + req.Header.Del("Forwarded") + removeHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + var ( + roundTripMutex sync.Mutex + roundTripDone bool + ) + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMutex.Lock() + defer roundTripMutex.Unlock() + if roundTripDone { + // If RoundTrip has returned, don't try to further modify + // the ResponseWriter's header map. + return nil + } + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + clear(h) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + roundTripMutex.Lock() + roundTripDone = true + roundTripMutex.Unlock() + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeHopByHopHeaders(res.Header) + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + err = utils.CopyCloseWithContext(ctx, rw, res.Body) + if err != nil { + if !errors.Is(err, context.Canceled) { + p.getErrorHandler()(rw, req, err) + } + return + } + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + http.NewResponseController(rw).Flush() + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +// removeHopByHopHeaders removes hop-by-hop headers. +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for sf := range strings.SplitSeq(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + +func (p *ReverseProxy) logf(format string, args ...any) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) + return + } + if !strings.EqualFold(reqUpType, resUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +func IsPrint(s string) bool { + for _, r := range s { + if r < ' ' || r > '~' { + return false + } + } + return true +}