diff --git a/.golangci.yml b/.golangci.yml
index 078682a..8eb57f6 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -2,15 +2,16 @@ version: "2"
linters:
default: all
disable:
- - bodyclose
+ # - bodyclose
- containedctx
- - contextcheck
+ # - contextcheck
- cyclop
- depguard
- - dupl
+ # - dupl
- err113
- exhaustive
- exhaustruct
+ - funcorder
- forcetypeassert
- gochecknoglobals
- gochecknoinits
@@ -18,7 +19,6 @@ linters:
- goconst
- gocyclo
- gomoddirectives
- - gosec
- gosmopolitan
- ireturn
- lll
@@ -27,12 +27,10 @@ linters:
- mnd
- nakedret
- nestif
- - nilnil
- nlreturn
- - noctx
- nonamedreturns
- paralleltest
- - prealloc
+ - revive
- rowserrcheck
- sqlclosecheck
- tagliatelle
diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml
index bf24f6d..ed25d81 100644
--- a/.trunk/trunk.yaml
+++ b/.trunk/trunk.yaml
@@ -21,7 +21,7 @@ lint:
- markdownlint
- yamllint
enabled:
- - checkov@3.2.416
+ - checkov@3.2.432
- golangci-lint2@2.1.6
- hadolint@2.12.1-beta
- actionlint@1.7.7
@@ -32,7 +32,7 @@ lint:
- prettier@3.5.3
- shellcheck@0.10.0
- shfmt@3.6.0
- - trufflehog@3.88.29
+ - trufflehog@3.88.33
actions:
disabled:
- trunk-announce
diff --git a/README.md b/README.md
index d091204..85e5172 100755
--- a/README.md
+++ b/README.md
@@ -16,6 +16,8 @@ A lightweight, simple, and performant reverse proxy with WebUI.
EN | 中文
+Have questions? Ask [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)! (Thanks to [@ismesid](https://github.com/arevindh))
+
diff --git a/README_CHT.md b/README_CHT.md
index f9e0ca2..bd95fc4 100644
--- a/README_CHT.md
+++ b/README_CHT.md
@@ -16,6 +16,8 @@
EN | 中文
+有疑問? 問 [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)!(鳴謝 [@ismesid](https://github.com/arevindh))
+
diff --git a/agent/cmd/main.go b/agent/cmd/main.go
index 7c891b8..d2167b6 100644
--- a/agent/cmd/main.go
+++ b/agent/cmd/main.go
@@ -1,11 +1,14 @@
package main
import (
+ "os"
+
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/server"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/task"
@@ -14,6 +17,12 @@ import (
)
func main() {
+ writer := zerolog.ConsoleWriter{
+ Out: os.Stderr,
+ TimeFormat: "01-02 15:04",
+ }
+ zerolog.TimeFieldFormat = writer.TimeFormat
+ log.Logger = zerolog.New(writer).Level(zerolog.InfoLevel).With().Timestamp().Logger()
ca := &agent.PEMPair{}
err := ca.Load(env.AgentCACert)
if err != nil {
@@ -34,11 +43,11 @@ func main() {
gperr.LogFatal("init SSL error", err)
}
- logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
- logging.Info().Msgf("Agent name: %s", env.AgentName)
- logging.Info().Msgf("Agent port: %d", env.AgentPort)
+ log.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
+ log.Info().Msgf("Agent name: %s", env.AgentName)
+ log.Info().Msgf("Agent port: %d", env.AgentPort)
- logging.Info().Msg(`
+ log.Info().Msg(`
Tips:
1. To change the agent name, you can set the AGENT_NAME environment variable.
2. To change the agent port, you can set the AGENT_PORT environment variable.
@@ -54,7 +63,7 @@ Tips:
server.StartAgentServer(t, opts)
if socketproxy.ListenAddr != "" {
- logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
+ log.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
opts := httpServer.Options{
Name: "docker",
HTTPAddr: socketproxy.ListenAddr,
diff --git a/agent/go.mod b/agent/go.mod
index 567f762..8b6e7d0 100644
--- a/agent/go.mod
+++ b/agent/go.mod
@@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy => ..
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
+replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
+
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327
@@ -15,6 +17,7 @@ require (
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
+ github.com/yusing/go-proxy/internal/utils v0.0.0
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
)
diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go
index 9840486..5887a66 100644
--- a/agent/pkg/agent/config.go
+++ b/agent/pkg/agent/config.go
@@ -15,8 +15,8 @@ import (
"time"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/certs"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/pkg"
)
@@ -115,7 +115,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
cfg.name = string(name)
- cfg.l = logging.With().Str("agent", cfg.name).Logger()
+ cfg.l = log.With().Str("agent", cfg.name).Logger()
// check agent version
agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion)
@@ -127,10 +127,10 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
agentVersion := pkg.ParseVersion(cfg.version)
if serverVersion.IsNewerMajorThan(agentVersion) {
- logging.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
+ log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
}
- logging.Info().Msgf("agent %q initialized", cfg.name)
+ log.Info().Msgf("agent %q initialized", cfg.name)
return nil
}
diff --git a/agent/pkg/handler/check_health_test.go b/agent/pkg/handler/check_health_test.go
index 4633ed6..2fc023f 100644
--- a/agent/pkg/handler/check_health_test.go
+++ b/agent/pkg/handler/check_health_test.go
@@ -172,9 +172,9 @@ func TestCheckHealthTCPUDP(t *testing.T) {
{
name: "InvalidHost",
scheme: "tcp",
- host: "invalid",
+ host: "",
port: 8080,
- expectedStatus: http.StatusOK,
+ expectedStatus: http.StatusBadRequest,
expectedHealthy: false,
},
{
@@ -188,9 +188,17 @@ func TestCheckHealthTCPUDP(t *testing.T) {
{
name: "InvalidHost",
scheme: "udp",
- host: "invalid",
+ host: "",
port: 8080,
- expectedStatus: http.StatusOK,
+ expectedStatus: http.StatusBadRequest,
+ expectedHealthy: false,
+ },
+ {
+ name: "Port in both host and port",
+ scheme: "tcp",
+ host: "localhost:1234",
+ port: 1234,
+ expectedStatus: http.StatusBadRequest,
expectedHealthy: false,
},
}
@@ -208,9 +216,11 @@ func TestCheckHealthTCPUDP(t *testing.T) {
require.Equal(t, recorder.Code, tt.expectedStatus)
- var result health.HealthCheckResult
- require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
- require.Equal(t, result.Healthy, tt.expectedHealthy)
+ if tt.expectedStatus == http.StatusOK {
+ var result health.HealthCheckResult
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
+ require.Equal(t, result.Healthy, tt.expectedHealthy)
+ }
})
}
}
diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go
index 9d07556..1835616 100644
--- a/agent/pkg/handler/handler.go
+++ b/agent/pkg/handler/handler.go
@@ -1,17 +1,14 @@
package handler
import (
- "context"
"fmt"
- "net"
"net/http"
- "net/http/httputil"
- "time"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/pkg"
+ socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
type ServeMux struct{ *http.ServeMux }
@@ -24,26 +21,6 @@ func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
}
-var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
-
-func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
- return dialer.DialContext(ctx, "unix", env.DockerSocket)
-}
-
-func dockerSocketHandler() http.HandlerFunc {
- rp := httputil.ReverseProxy{
- Director: func(r *http.Request) {
- r.URL.Scheme = "http"
- r.URL.Host = "api.moby.localhost"
- r.RequestURI = r.URL.String()
- },
- Transport: &http.Transport{
- DialContext: dialDockerSocket,
- },
- }
- return rp.ServeHTTP
-}
-
func NewAgentHandler() http.Handler {
mux := ServeMux{http.NewServeMux()}
@@ -54,6 +31,6 @@ func NewAgentHandler() http.Handler {
})
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
- mux.ServeMux.HandleFunc("/", dockerSocketHandler())
+ mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler(env.DockerSocket))
return mux
}
diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go
index 9be4631..4b9e2b7 100644
--- a/agent/pkg/server/server.go
+++ b/agent/pkg/server/server.go
@@ -6,9 +6,9 @@ import (
"fmt"
"net/http"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/handler"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/task"
)
@@ -33,12 +33,11 @@ func StartAgentServer(parent task.Parent, opt Options) {
tlsConfig.ClientAuth = tls.NoClientCert
}
- logger := logging.GetLogger()
agentServer := &http.Server{
Addr: fmt.Sprintf(":%d", opt.Port),
Handler: handler.NewAgentHandler(),
TLSConfig: tlsConfig,
}
- server.Start(parent, agentServer, nil, logger)
+ server.Start(parent, agentServer, nil, &log.Logger)
}
diff --git a/cmd/main.go b/cmd/main.go
index 125ec69..c2cb3c1 100755
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -4,6 +4,7 @@ import (
"os"
"sync"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/auth"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
@@ -35,8 +36,8 @@ func main() {
initProfiling()
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
- logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
- logging.Trace().Msg("trace enabled")
+ log.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
+ log.Trace().Msg("trace enabled")
parallel(
dnsproviders.InitProviders,
homepage.InitIconListCache,
@@ -45,7 +46,7 @@ func main() {
)
if common.APIJWTSecret == nil {
- logging.Warn().Msg("API_JWT_SECRET is not set, using random key")
+ log.Warn().Msg("API_JWT_SECRET is not set, using random key")
common.APIJWTSecret = common.RandomJWTKey()
}
@@ -62,7 +63,7 @@ func main() {
Proxy: true,
})
if err := auth.Initialize(); err != nil {
- logging.Fatal().Err(err).Msg("failed to initialize authentication")
+ log.Fatal().Err(err).Msg("failed to initialize authentication")
}
// API Handler needs to start after auth is initialized.
cfg.StartServers(&config.StartServersOptions{
@@ -78,7 +79,7 @@ func main() {
func prepareDirectory(dir string) {
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0o755); err != nil {
- logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
+ log.Fatal().Msgf("failed to create directory %s: %v", dir, err)
}
}
}
diff --git a/go.mod b/go.mod
index 38af281..df26fd2 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy/agent => ./agent
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
+replace github.com/yusing/go-proxy/internal/utils => ./internal/utils
+
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
@@ -45,7 +47,7 @@ require (
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000
- go.uber.org/atomic v1.11.0
+ github.com/yusing/go-proxy/internal/utils v0.0.0
)
require (
@@ -219,6 +221,7 @@ require (
go.opentelemetry.io/otel v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
+ go.uber.org/atomic v1.11.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.5.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
@@ -226,7 +229,7 @@ require (
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
- golang.org/x/text v0.25.0
+ golang.org/x/text v0.25.0 // indirect
golang.org/x/tools v0.33.0 // indirect
google.golang.org/api v0.234.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect
diff --git a/internal/acl/config.go b/internal/acl/config.go
index 2b7c5ef..8d2f306 100644
--- a/internal/acl/config.go
+++ b/internal/acl/config.go
@@ -5,9 +5,9 @@ import (
"time"
"github.com/puzpuzpuz/xsync/v4"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/task"
@@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool {
return c.created.Add(cacheTTL).Before(utils.TimeNow())
}
-//TODO: add stats
+// TODO: add stats
const (
ACLAllow = "allow"
@@ -97,7 +97,7 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
if c.valErr != nil {
return c.valErr
}
- logging.Info().
+ log.Info().
Str("default", c.Default).
Bool("allow_local", c.allowLocal).
Int("allow_rules", len(c.Allow)).
diff --git a/internal/acl/matcher_test.go b/internal/acl/matcher_test.go
index 384b9b8..0d015b4 100644
--- a/internal/acl/matcher_test.go
+++ b/internal/acl/matcher_test.go
@@ -6,7 +6,7 @@ import (
"testing"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
func TestMatchers(t *testing.T) {
@@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) {
}
var mathers Matchers
- err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
+ err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/acl/tcp_listener.go b/internal/acl/tcp_listener.go
index b6b4a7a..59e40ee 100644
--- a/internal/acl/tcp_listener.go
+++ b/internal/acl/tcp_listener.go
@@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil }
func (noConn) SetReadDeadline(t time.Time) error { return nil }
func (noConn) SetWriteDeadline(t time.Time) error { return nil }
-func (cfg *Config) WrapTCP(lis net.Listener) net.Listener {
- if cfg == nil {
+func (c *Config) WrapTCP(lis net.Listener) net.Listener {
+ if c == nil {
return lis
}
return &TCPListener{
- acl: cfg,
+ acl: c,
lis: lis,
}
}
diff --git a/internal/api/v1/certapi/renew.go b/internal/api/v1/certapi/renew.go
index b274ef4..084db82 100644
--- a/internal/api/v1/certapi/renew.go
+++ b/internal/api/v1/certapi/renew.go
@@ -3,9 +3,9 @@ package certapi
import (
"net/http"
+ "github.com/rs/zerolog/log"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
)
@@ -36,7 +36,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
gperr.LogError("failed to obtain cert", err)
_ = gpwebsocket.WriteText(conn, err.Error())
} else {
- logging.Info().Msg("cert obtained successfully")
+ log.Info().Msg("cert obtained successfully")
}
}()
for {
diff --git a/internal/api/v1/dockerapi/logs.go b/internal/api/v1/dockerapi/logs.go
index 385b59e..c313fed 100644
--- a/internal/api/v1/dockerapi/logs.go
+++ b/internal/api/v1/dockerapi/logs.go
@@ -9,12 +9,13 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/pkg/stdcopy"
"github.com/gorilla/websocket"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/task"
)
+// FIXME: agent logs not updating.
func Logs(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
server := r.PathValue("server")
@@ -68,7 +69,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
return
}
- logging.Err(err).
+ log.Err(err).
Str("server", server).
Str("container", containerID).
Msg("failed to de-multiplex logs")
diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go
index 5d3d022..0fc0335 100644
--- a/internal/auth/oauth_refresh.go
+++ b/internal/auth/oauth_refresh.go
@@ -11,9 +11,9 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/jsonstore"
- "github.com/yusing/go-proxy/internal/logging"
"golang.org/x/oauth2"
)
@@ -108,11 +108,11 @@ func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
RefreshToken: token,
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
})
- logging.Debug().Str("username", username).Msg("stored oauth refresh token")
+ log.Debug().Str("username", username).Msg("stored oauth refresh token")
}
func invalidateOAuthRefreshToken(sessionID sessionID) {
- logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
+ log.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
oauthRefreshTokens.Delete(string(sessionID))
}
@@ -127,7 +127,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
signed, err := jwtToken.SignedString(common.APIJWTSecret)
if err != nil {
- logging.Err(err).Msg("failed to sign session token")
+ log.Err(err).Msg("failed to sign session token")
return
}
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
@@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
return nil, refreshToken.err
}
- idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
+ idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken)
if err != nil {
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
return nil, refreshToken.err
@@ -205,7 +205,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
sessionID := newSessionID()
- logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
+ log.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
refreshToken.result = &RefreshResult{
diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go
index 4c8c0b2..7f9d4ed 100644
--- a/internal/auth/oidc.go
+++ b/internal/auth/oidc.go
@@ -12,9 +12,9 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/utils"
"golang.org/x/oauth2"
@@ -38,8 +38,8 @@ type (
const (
CookieOauthState = "godoxy_oidc_state"
- CookieOauthToken = "godoxy_oauth_token"
- CookieOauthSessionToken = "godoxy_session_token"
+ CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
+ CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
)
const (
@@ -79,7 +79,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
if err != nil && provider.EndSessionEndpoint() != "" {
// non critical, just warn
- logging.Warn().
+ log.Warn().
Str("issuer", issuerURL).
Err(err).
Msg("failed to parse end session URL")
@@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
}
-func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
+func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
if !ok {
return "", nil, errMissingIDToken
@@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
return
}
// clear cookies then redirect to home
- logging.Err(err).Msg("failed to refresh token")
+ log.Err(err).Msg("failed to refresh token")
auth.clearCookie(w, r)
http.Redirect(w, r, "/", http.StatusFound)
return
@@ -201,11 +201,12 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
userAllowed := slices.Contains(auth.allowedUsers, user)
- if !userAllowed {
- return false
+ if userAllowed {
+ return true
}
if len(auth.allowedGroups) == 0 {
- return true
+ // user is not allowed, but no groups are allowed
+ return false
}
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
}
@@ -257,7 +258,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
return
}
- idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token)
+ idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
if err != nil {
gphttp.ServerError(w, r, err)
return
diff --git a/internal/autocert/config.go b/internal/autocert/config.go
index 0934d96..a71d429 100644
--- a/internal/autocert/config.go
+++ b/internal/autocert/config.go
@@ -5,13 +5,16 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
+ "net/http"
"os"
"regexp"
"github.com/go-acme/lego/v4/certcrypto"
+ "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
+ "github.com/rs/zerolog/log"
+ "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -22,13 +25,19 @@ type Config struct {
KeyPath string `json:"key_path,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Provider string `json:"provider,omitempty"`
+ CADirURL string `json:"ca_dir_url,omitempty"`
Options map[string]any `json:"options,omitempty"`
+
+ HTTPClient *http.Client `json:"-"` // for tests only
+
+ challengeProvider challenge.Provider
}
var (
ErrMissingDomain = gperr.New("missing field 'domains'")
ErrMissingEmail = gperr.New("missing field 'email'")
ErrMissingProvider = gperr.New("missing field 'provider'")
+ ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
ErrInvalidDomain = gperr.New("invalid domain")
ErrUnknownProvider = gperr.New("unknown provider")
)
@@ -36,6 +45,7 @@ var (
const (
ProviderLocal = "local"
ProviderPseudo = "pseudo"
+ ProviderCustom = "custom"
)
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
@@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error {
}
b := gperr.NewBuilder("autocert errors")
+ if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
+ b.Add(ErrMissingCADirURL)
+ }
+
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
@@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error {
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
- for i, d := range cfg.Domains {
- if !domainOrWildcardRE.MatchString(d) {
- b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
+ if cfg.Provider != ProviderCustom {
+ for i, d := range cfg.Domains {
+ if !domainOrWildcardRE.MatchString(d) {
+ b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
+ }
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
- b.Add(ErrUnknownProvider.
- Subject(cfg.Provider).
- With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
+ if cfg.Provider != ProviderCustom {
+ b.Add(ErrUnknownProvider.
+ Subject(cfg.Provider).
+ With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
+ }
} else {
- _, err := providerConstructor(cfg.Options)
+ provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
+ } else {
+ cfg.challengeProvider = provider
}
}
}
+
+ if cfg.challengeProvider == nil {
+ cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
+ }
return b.Error()
}
@@ -100,8 +124,7 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if privKey, err = cfg.LoadACMEKey(); err != nil {
- logging.Info().Err(err).Msg("load ACME private key failed")
- logging.Info().Msg("generate new ACME private key")
+ log.Info().Err(err).Msg("failed to load ACME private key, generating a now one")
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, gperr.New("generate ACME private key").With(err)
@@ -118,12 +141,23 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
}
legoCfg := lego.NewConfig(user)
- legoCfg.Certificate.KeyType = certcrypto.RSA2048
+ legoCfg.Certificate.KeyType = certcrypto.EC256
+
+ if cfg.HTTPClient != nil {
+ legoCfg.HTTPClient = cfg.HTTPClient
+ }
+
+ if cfg.CADirURL != "" {
+ legoCfg.CADirURL = cfg.CADirURL
+ }
return user, legoCfg, nil
}
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
+ if common.IsTest {
+ return nil, os.ErrNotExist
+ }
data, err := os.ReadFile(cfg.ACMEKeyPath)
if err != nil {
return nil, err
@@ -132,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
}
func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error {
+ if common.IsTest {
+ return nil
+ }
data, err := x509.MarshalECPrivateKey(key)
if err != nil {
return err
diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go
index 8c2585b..12c0125 100644
--- a/internal/autocert/provider.go
+++ b/internal/autocert/provider.go
@@ -5,18 +5,20 @@ import (
"crypto/x509"
"errors"
"fmt"
+ "maps"
"os"
"path"
- "reflect"
- "sort"
+ "slices"
+ "strings"
"time"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
@@ -76,13 +78,11 @@ func (p *Provider) ObtainCert() error {
}
if p.cfg.Provider == ProviderPseudo {
- t := time.NewTicker(1000 * time.Millisecond)
- defer t.Stop()
- logging.Info().Msg("init client for pseudo provider")
- <-t.C
- logging.Info().Msg("registering acme for pseudo provider")
- <-t.C
- logging.Info().Msg("obtained cert for pseudo provider")
+ log.Info().Msg("init client for pseudo provider")
+ <-time.After(time.Second)
+ log.Info().Msg("registering acme for pseudo provider")
+ <-time.After(time.Second)
+ log.Info().Msg("obtained cert for pseudo provider")
return nil
}
@@ -107,7 +107,7 @@ func (p *Provider) ObtainCert() error {
})
if err != nil {
p.legoCert = nil
- logging.Err(err).Msg("cert renew failed, fallback to obtain")
+ log.Err(err).Msg("cert renew failed, fallback to obtain")
} else {
p.legoCert = cert
}
@@ -154,7 +154,7 @@ func (p *Provider) LoadCert() error {
p.tlsCert = &cert
p.certExpiries = expiries
- logging.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
+ log.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
@@ -219,13 +219,7 @@ func (p *Provider) initClient() error {
return err
}
- generator := Providers[p.cfg.Provider]
- legoProvider, pErr := generator(p.cfg.Options)
- if pErr != nil {
- return pErr
- }
-
- err = legoClient.Challenge.SetDNS01Provider(legoProvider)
+ err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider)
if err != nil {
return err
}
@@ -240,7 +234,7 @@ func (p *Provider) registerACME() error {
}
if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil {
p.user.Registration = reg
- logging.Info().Msg("reused acme registration from private key")
+ log.Info().Msg("reused acme registration from private key")
return nil
}
@@ -249,11 +243,14 @@ func (p *Provider) registerACME() error {
return err
}
p.user.Registration = reg
- logging.Info().Interface("reg", reg).Msg("acme registered")
+ log.Info().Interface("reg", reg).Msg("acme registered")
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) error {
+ if common.IsTest {
+ return nil
+ }
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
@@ -283,22 +280,19 @@ func (p *Provider) certState() CertState {
return CertStateExpired
}
- certDomains := make([]string, len(p.certExpiries))
- wantedDomains := make([]string, len(p.cfg.Domains))
- i := 0
- for domain := range p.certExpiries {
- certDomains[i] = domain
- i++
- }
- copy(wantedDomains, p.cfg.Domains)
- sort.Strings(wantedDomains)
- sort.Strings(certDomains)
-
- if !reflect.DeepEqual(certDomains, wantedDomains) {
- logging.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
+ if len(p.certExpiries) != len(p.cfg.Domains) {
return CertStateMismatch
}
+ for i := range len(p.cfg.Domains) {
+ if _, ok := p.certExpiries[p.cfg.Domains[i]]; !ok {
+ log.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s",
+ strings.Join(slices.Collect(maps.Keys(p.certExpiries)), ", "),
+ strings.Join(p.cfg.Domains, ", "))
+ return CertStateMismatch
+ }
+ }
+
return CertStateValid
}
@@ -309,9 +303,9 @@ func (p *Provider) renewIfNeeded() error {
switch p.certState() {
case CertStateExpired:
- logging.Info().Msg("certs expired, renewing")
+ log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
- logging.Info().Msg("cert domains mismatch with config, renewing")
+ log.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
}
diff --git a/internal/autocert/provider_test/custom_test.go b/internal/autocert/provider_test/custom_test.go
new file mode 100644
index 0000000..6490761
--- /dev/null
+++ b/internal/autocert/provider_test/custom_test.go
@@ -0,0 +1,453 @@
+//nolint:errchkjson,errcheck
+package provider_test
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/base64"
+ "encoding/json"
+ "encoding/pem"
+ "io"
+ "math/big"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "github.com/yusing/go-proxy/internal/autocert"
+ "github.com/yusing/go-proxy/internal/dnsproviders"
+)
+
+func TestMain(m *testing.M) {
+ dnsproviders.InitProviders()
+ m.Run()
+}
+
+func TestCustomProvider(t *testing.T) {
+ t.Run("valid custom provider with step-ca", func(t *testing.T) {
+ cfg := &autocert.Config{
+ Email: "test@example.com",
+ Domains: []string{"example.com", "*.example.com"},
+ Provider: autocert.ProviderCustom,
+ CADirURL: "https://ca.example.com:9000/acme/acme/directory",
+ CertPath: "certs/custom.crt",
+ KeyPath: "certs/custom.key",
+ ACMEKeyPath: "certs/custom-acme.key",
+ }
+
+ err := cfg.Validate()
+ require.NoError(t, err)
+
+ user, legoCfg, err := cfg.GetLegoConfig()
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.NotNil(t, legoCfg)
+ require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL)
+ require.Equal(t, "test@example.com", user.Email)
+ })
+
+ t.Run("custom provider missing CADirURL", func(t *testing.T) {
+ cfg := &autocert.Config{
+ Email: "test@example.com",
+ Domains: []string{"example.com"},
+ Provider: autocert.ProviderCustom,
+ // CADirURL is missing
+ }
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
+ })
+
+ t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
+ cfg := &autocert.Config{
+ Email: "admin@internal.com",
+ Domains: []string{"internal.example.com", "api.internal.example.com"},
+ Provider: autocert.ProviderCustom,
+ CADirURL: "https://step-ca.internal:443/acme/acme/directory",
+ CertPath: "certs/internal.crt",
+ KeyPath: "certs/internal.key",
+ ACMEKeyPath: "certs/internal-acme.key",
+ }
+
+ err := cfg.Validate()
+ require.NoError(t, err)
+
+ user, legoCfg, err := cfg.GetLegoConfig()
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.NotNil(t, legoCfg)
+ require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
+ require.Equal(t, "admin@internal.com", user.Email)
+
+ provider := autocert.NewProvider(cfg, user, legoCfg)
+ require.NotNil(t, provider)
+ require.Equal(t, autocert.ProviderCustom, provider.GetName())
+ require.Equal(t, "certs/internal.crt", provider.GetCertPath())
+ require.Equal(t, "certs/internal.key", provider.GetKeyPath())
+ })
+}
+
+func TestObtainCertFromCustomProvider(t *testing.T) {
+ // Create a test ACME server
+ acmeServer := newTestACMEServer(t)
+ defer acmeServer.Close()
+
+ t.Run("obtain cert from custom step-ca server", func(t *testing.T) {
+ cfg := &autocert.Config{
+ Email: "test@example.com",
+ Domains: []string{"test.example.com"},
+ Provider: autocert.ProviderCustom,
+ CADirURL: acmeServer.URL() + "/acme/acme/directory",
+ CertPath: "certs/stepca-test.crt",
+ KeyPath: "certs/stepca-test.key",
+ ACMEKeyPath: "certs/stepca-test-acme.key",
+ HTTPClient: acmeServer.httpClient(),
+ }
+
+ err := error(cfg.Validate())
+ require.NoError(t, err)
+
+ user, legoCfg, err := cfg.GetLegoConfig()
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.NotNil(t, legoCfg)
+
+ provider := autocert.NewProvider(cfg, user, legoCfg)
+ require.NotNil(t, provider)
+
+ // Test obtaining certificate
+ err = provider.ObtainCert()
+ require.NoError(t, err)
+
+ // Verify certificate was obtained
+ cert, err := provider.GetCert(nil)
+ require.NoError(t, err)
+ require.NotNil(t, cert)
+
+ // Verify certificate properties
+ x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
+ require.NoError(t, err)
+ require.Contains(t, x509Cert.DNSNames, "test.example.com")
+ require.True(t, time.Now().Before(x509Cert.NotAfter))
+ require.True(t, time.Now().After(x509Cert.NotBefore))
+ })
+}
+
+// testACMEServer implements a minimal ACME server for testing.
+type testACMEServer struct {
+ server *httptest.Server
+ caCert *x509.Certificate
+ caKey *rsa.PrivateKey
+ clientCSRs map[string]*x509.CertificateRequest
+ orderID string
+}
+
+func newTestACMEServer(t *testing.T) *testACMEServer {
+ t.Helper()
+
+ // Generate CA certificate and key
+ caKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ caTemplate := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test CA"},
+ Country: []string{"US"},
+ Province: []string{""},
+ Locality: []string{"Test"},
+ StreetAddress: []string{""},
+ PostalCode: []string{""},
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(365 * 24 * time.Hour),
+ IsCA: true,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ BasicConstraintsValid: true,
+ }
+
+ caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
+ require.NoError(t, err)
+
+ caCert, err := x509.ParseCertificate(caCertDER)
+ require.NoError(t, err)
+
+ acme := &testACMEServer{
+ caCert: caCert,
+ caKey: caKey,
+ clientCSRs: make(map[string]*x509.CertificateRequest),
+ orderID: "test-order-123",
+ }
+
+ mux := http.NewServeMux()
+ acme.setupRoutes(mux)
+
+ acme.server = httptest.NewTLSServer(mux)
+ return acme
+}
+
+func (s *testACMEServer) Close() {
+ s.server.Close()
+}
+
+func (s *testACMEServer) URL() string {
+ return s.server.URL
+}
+
+func (s *testACMEServer) httpClient() *http.Client {
+ return &http.Client{
+ Transport: &http.Transport{
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }).DialContext,
+ TLSHandshakeTimeout: 30 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true, //nolint:gosec
+ },
+ },
+ }
+}
+
+func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
+ // ACME directory endpoint
+ mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
+
+ // ACME endpoints
+ mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
+ mux.HandleFunc("/acme/new-account", s.handleNewAccount)
+ mux.HandleFunc("/acme/new-order", s.handleNewOrder)
+ mux.HandleFunc("/acme/authz/", s.handleAuthorization)
+ mux.HandleFunc("/acme/chall/", s.handleChallenge)
+ mux.HandleFunc("/acme/order/", s.handleOrder)
+ mux.HandleFunc("/acme/cert/", s.handleCertificate)
+}
+
+func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
+ directory := map[string]interface{}{
+ "newNonce": s.server.URL + "/acme/new-nonce",
+ "newAccount": s.server.URL + "/acme/new-account",
+ "newOrder": s.server.URL + "/acme/new-order",
+ "keyChange": s.server.URL + "/acme/key-change",
+ "meta": map[string]interface{}{
+ "termsOfService": s.server.URL + "/terms",
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(directory)
+}
+
+func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Replay-Nonce", "test-nonce-12345")
+ w.WriteHeader(http.StatusOK)
+}
+
+func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
+ account := map[string]interface{}{
+ "status": "valid",
+ "contact": []string{"mailto:test@example.com"},
+ "orders": s.server.URL + "/acme/orders",
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Location", s.server.URL+"/acme/account/1")
+ w.Header().Set("Replay-Nonce", "test-nonce-67890")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(account)
+}
+
+func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
+ authzID := "test-authz-456"
+
+ order := map[string]interface{}{
+ "status": "ready", // Skip pending state for simplicity
+ "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
+ "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
+ "authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
+ "finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
+ w.Header().Set("Replay-Nonce", "test-nonce-order")
+ w.WriteHeader(http.StatusCreated)
+ json.NewEncoder(w).Encode(order)
+}
+
+func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
+ authz := map[string]interface{}{
+ "status": "valid", // Skip challenge validation for simplicity
+ "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
+ "identifier": map[string]string{"type": "dns", "value": "test.example.com"},
+ "challenges": []map[string]interface{}{
+ {
+ "type": "dns-01",
+ "status": "valid",
+ "url": s.server.URL + "/acme/chall/test-chall-789",
+ "token": "test-token-abc123",
+ },
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Replay-Nonce", "test-nonce-authz")
+ json.NewEncoder(w).Encode(authz)
+}
+
+func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
+ challenge := map[string]interface{}{
+ "type": "dns-01",
+ "status": "valid",
+ "url": r.URL.String(),
+ "token": "test-token-abc123",
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Replay-Nonce", "test-nonce-chall")
+ json.NewEncoder(w).Encode(challenge)
+}
+
+func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
+ if strings.HasSuffix(r.URL.Path, "/finalize") {
+ s.handleFinalize(w, r)
+ return
+ }
+
+ certURL := s.server.URL + "/acme/cert/" + s.orderID
+ order := map[string]interface{}{
+ "status": "valid",
+ "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
+ "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
+ "certificate": certURL,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Replay-Nonce", "test-nonce-order-get")
+ json.NewEncoder(w).Encode(order)
+}
+
+func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
+ // Read the JWS payload
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "Failed to read request", http.StatusBadRequest)
+ return
+ }
+
+ // Extract CSR from JWS payload
+ csr, err := s.extractCSRFromJWS(body)
+ if err != nil {
+ http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ // Store the CSR for certificate generation
+ s.clientCSRs[s.orderID] = csr
+
+ certURL := s.server.URL + "/acme/cert/" + s.orderID
+ order := map[string]interface{}{
+ "status": "valid",
+ "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
+ "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
+ "certificate": certURL,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
+ w.Header().Set("Replay-Nonce", "test-nonce-finalize")
+ json.NewEncoder(w).Encode(order)
+}
+
+func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
+ // Parse the JWS structure
+ var jws struct {
+ Protected string `json:"protected"`
+ Payload string `json:"payload"`
+ Signature string `json:"signature"`
+ }
+
+ if err := json.Unmarshal(jwsData, &jws); err != nil {
+ return nil, err
+ }
+
+ // Decode the payload
+ payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the finalize request
+ var finalizeReq struct {
+ CSR string `json:"csr"`
+ }
+
+ if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
+ return nil, err
+ }
+
+ // Decode the CSR
+ csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse the CSR
+ csr, err := x509.ParseCertificateRequest(csrBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ return csr, nil
+}
+
+func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
+ // Extract order ID from URL
+ orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
+
+ // Get the CSR for this order
+ csr, exists := s.clientCSRs[orderID]
+ if !exists {
+ http.Error(w, "No CSR found for order", http.StatusBadRequest)
+ return
+ }
+
+ // Create certificate using the public key from the client's CSR
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(2),
+ Subject: pkix.Name{
+ Organization: []string{"Test Cert"},
+ Country: []string{"US"},
+ },
+ DNSNames: csr.DNSNames,
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(90 * 24 * time.Hour),
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+
+ // Use the public key from the CSR and sign with CA key
+ certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // Return certificate chain
+ certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+ caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
+
+ w.Header().Set("Content-Type", "application/pem-certificate-chain")
+ w.Header().Set("Replay-Nonce", "test-nonce-cert")
+ w.Write(append(certPEM, caPEM...))
+}
diff --git a/internal/autocert/provider_test/ovh_test.go b/internal/autocert/provider_test/ovh_test.go
index 203e9ef..e268af4 100644
--- a/internal/autocert/provider_test/ovh_test.go
+++ b/internal/autocert/provider_test/ovh_test.go
@@ -6,7 +6,7 @@ import (
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/require"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
// type Config struct {
@@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
- require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
+ require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg))
require.Equal(t, cfgExpected, cfg)
}
diff --git a/internal/autocert/providers.go b/internal/autocert/providers.go
index dbbcb6d..5b73e39 100644
--- a/internal/autocert/providers.go
+++ b/internal/autocert/providers.go
@@ -3,7 +3,7 @@ package autocert
import (
"github.com/go-acme/lego/v4/challenge"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type Generator func(map[string]any) (challenge.Provider, gperr.Error)
@@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider](
) Generator {
return func(opt map[string]any) (challenge.Provider, gperr.Error) {
cfg := defaultCfg()
- err := utils.MapUnmarshalValidate(opt, &cfg)
- if err != nil {
- return nil, err
+ if len(opt) > 0 {
+ err := serialization.MapUnmarshalValidate(opt, &cfg)
+ if err != nil {
+ return nil, err
+ }
}
p, pErr := newProvider(cfg)
return p, gperr.Wrap(pErr)
diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go
index b436bed..62c589e 100644
--- a/internal/autocert/setup.go
+++ b/internal/autocert/setup.go
@@ -4,7 +4,7 @@ import (
"errors"
"os"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -13,14 +13,14 @@ func (p *Provider) Setup() (err error) {
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
- logging.Debug().Msg("obtaining cert due to error loading cert")
+ log.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
}
for _, expiry := range p.GetExpiries() {
- logging.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
+ log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
break
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 621cfa0..02cac28 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -10,20 +10,20 @@ import (
"time"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/api"
autocert "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/entrypoint"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
proxy "github.com/yusing/go-proxy/internal/route/provider"
+ "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
- "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
"github.com/yusing/go-proxy/internal/watcher"
@@ -96,10 +96,10 @@ func OnConfigChange(ev []events.Event) {
// just reload once and check the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
- logging.Warn().Msg(cfgRenameWarn)
+ log.Warn().Msg(cfgRenameWarn)
return
case events.ActionFileDeleted:
- logging.Warn().Msg(cfgDeleteWarn)
+ log.Warn().Msg(cfgDeleteWarn)
return
}
@@ -161,7 +161,7 @@ func (cfg *Config) Start(opts ...*StartServersOptions) {
func (cfg *Config) StartAutoCert() {
autocert := cfg.autocertProvider
if autocert == nil {
- logging.Info().Msg("autocert not configured")
+ log.Info().Msg("autocert not configured")
return
}
@@ -223,7 +223,7 @@ func (cfg *Config) load() gperr.Error {
}
model := config.DefaultConfig()
- if err := utils.UnmarshalValidateYAML(data, model); err != nil {
+ if err := serialization.UnmarshalValidateYAML(data, model); err != nil {
gperr.LogFatal(errMsg, err)
}
@@ -374,6 +374,6 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
})
- logging.Info().Msg(results.String())
+ log.Info().Msg(results.String())
return errs.Error()
}
diff --git a/internal/config/types/config.go b/internal/config/types/config.go
index 4460025..80fa48e 100644
--- a/internal/config/types/config.go
+++ b/internal/config/types/config.go
@@ -14,7 +14,7 @@ import (
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
@@ -93,14 +93,14 @@ func HasInstance() bool {
func Validate(data []byte) gperr.Error {
var model Config
- return utils.UnmarshalValidateYAML(data, &model)
+ return serialization.UnmarshalValidateYAML(data, &model)
}
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
func init() {
- utils.RegisterDefaultValueFactory(DefaultConfig)
- utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
+ serialization.RegisterDefaultValueFactory(DefaultConfig)
+ serialization.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
domains := fl.Field().Interface().([]string)
for _, domain := range domains {
if !matchDomainsRegex.MatchString(domain) {
@@ -109,7 +109,7 @@ func init() {
}
return true
})
- utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
+ serialization.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
m := fl.Field().Interface().(map[string]string)
for k := range m {
if k == "" {
diff --git a/internal/dnsproviders/go.mod b/internal/dnsproviders/go.mod
index 5506b95..65efe61 100644
--- a/internal/dnsproviders/go.mod
+++ b/internal/dnsproviders/go.mod
@@ -4,6 +4,8 @@ go 1.24.3
replace github.com/yusing/go-proxy => ../..
+replace github.com/yusing/go-proxy/internal/utils => ../utils
+
require (
github.com/go-acme/lego/v4 v4.23.1
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
@@ -156,6 +158,7 @@ require (
github.com/vultr/govultr/v3 v3.20.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
+ github.com/yusing/go-proxy/internal/utils v0.0.0 // indirect
go.mongodb.org/mongo-driver v1.17.3 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
diff --git a/internal/docker/client.go b/internal/docker/client.go
index afadf20..6d15d88 100644
--- a/internal/docker/client.go
+++ b/internal/docker/client.go
@@ -13,13 +13,14 @@ import (
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
+// TODO: implement reconnect here.
type (
SharedClient struct {
*client.Client
@@ -83,7 +84,7 @@ func closeTimedOutClients() {
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
delete(clientMap, c.Key())
c.Client.Close()
- logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
+ log.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
}
}
}
@@ -148,7 +149,7 @@ func NewClient(host string) (*SharedClient, error) {
default:
helper, err := connhelper.GetConnectionHelper(host)
if err != nil {
- logging.Panic().Err(err).Msg("failed to get connection helper")
+ log.Panic().Err(err).Msg("failed to get connection helper")
}
if helper != nil {
httpClient := &http.Client{
@@ -189,10 +190,10 @@ func NewClient(host string) (*SharedClient, error) {
c.dial = client.Dialer()
}
if c.addr == "" {
- c.addr = c.Client.DaemonHost()
+ c.addr = c.DaemonHost()
}
- defer logging.Debug().Str("host", host).Msg("docker client initialized")
+ defer log.Debug().Str("host", host).Msg("docker client initialized")
clientMap[c.Key()] = c
return c, nil
diff --git a/internal/docker/container.go b/internal/docker/container.go
index 6044b30..ddc4ac6 100644
--- a/internal/docker/container.go
+++ b/internal/docker/container.go
@@ -9,11 +9,12 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/go-connections/nat"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -90,7 +91,7 @@ func FromDocker(c *container.SummaryTrimmed, dockerHost string) (res *Container)
var ok bool
res.Agent, ok = config.GetInstance().GetAgent(dockerHost)
if !ok {
- logging.Error().Msgf("agent %q not found", dockerHost)
+ log.Error().Msgf("agent %q not found", dockerHost)
}
}
@@ -183,7 +184,7 @@ func (c *Container) setPublicHostname() {
}
url, err := url.Parse(c.DockerHost)
if err != nil {
- logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
+ log.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
c.PublicHostname = "127.0.0.1"
return
}
@@ -224,7 +225,7 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
ContainerName: c.ContainerName,
},
}
- err := utils.MapUnmarshalValidate(cfg, idwCfg)
+ err := serialization.MapUnmarshalValidate(cfg, idwCfg)
if err != nil {
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
} else {
diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go
index e3683db..aa35b9c 100644
--- a/internal/entrypoint/entrypoint.go
+++ b/internal/entrypoint/entrypoint.go
@@ -6,7 +6,7 @@ import (
"net/http"
"strings"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/logging/accesslog"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
@@ -50,7 +50,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
}
ep.middleware = mid
- logging.Debug().Msg("entrypoint middleware loaded")
+ log.Debug().Msg("entrypoint middleware loaded")
return nil
}
@@ -64,7 +64,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request
if err != nil {
return
}
- logging.Debug().Msg("entrypoint access logger created")
+ log.Debug().Msg("entrypoint access logger created")
return
}
@@ -89,7 +89,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
- logging.Err(err).
+ log.Err(err).
Str("method", r.Method).
Str("url", r.URL.String()).
Str("remote", r.RemoteAddr).
@@ -99,7 +99,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if _, err := w.Write(errorPage); err != nil {
- logging.Err(err).Msg("failed to write error page")
+ log.Err(err).Msg("failed to write error page")
}
} else {
http.Error(w, err.Error(), http.StatusNotFound)
diff --git a/internal/gperr/log.go b/internal/gperr/log.go
index 5be1bce..a13336b 100644
--- a/internal/gperr/log.go
+++ b/internal/gperr/log.go
@@ -4,8 +4,8 @@ import (
"os"
"github.com/rs/zerolog"
+ zerologlog "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
- "github.com/yusing/go-proxy/internal/logging"
)
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
@@ -13,7 +13,7 @@ func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger)
if len(logger) > 0 {
l = logger[0]
} else {
- l = logging.GetLogger()
+ l = &zerologlog.Logger
}
l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error())
switch level {
diff --git a/internal/homepage/homepage.go b/internal/homepage/homepage.go
index 2d9bcc3..13d78bb 100644
--- a/internal/homepage/homepage.go
+++ b/internal/homepage/homepage.go
@@ -5,7 +5,7 @@ import (
"slices"
"github.com/yusing/go-proxy/internal/homepage/widgets"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
@@ -32,7 +32,7 @@ type (
)
func init() {
- utils.RegisterDefaultValueFactory(func() *ItemConfig {
+ serialization.RegisterDefaultValueFactory(func() *ItemConfig {
return &ItemConfig{
Show: true,
}
diff --git a/internal/homepage/icon_cache.go b/internal/homepage/icon_cache.go
index 4f0f658..2a4f5c2 100644
--- a/internal/homepage/icon_cache.go
+++ b/internal/homepage/icon_cache.go
@@ -6,9 +6,9 @@ import (
"sync"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/jsonstore"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic"
@@ -74,7 +74,7 @@ func pruneExpiredIconCache() {
}
}
if nPruned > 0 {
- logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
+ log.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
}
}
@@ -87,7 +87,7 @@ func loadIconCache(key string) *FetchResult {
defer iconMu.RUnlock()
icon, ok := iconCache.Load(key)
if ok && len(icon.Icon) > 0 {
- logging.Debug().
+ log.Debug().
Str("key", key).
Msg("icon found in cache")
icon.LastAccess.Store(utils.TimeNow())
@@ -99,7 +99,7 @@ func loadIconCache(key string) *FetchResult {
func storeIconCache(key string, result *FetchResult) {
icon := result.Icon
if len(icon) > maxIconSize {
- logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
+ log.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
return
}
@@ -109,7 +109,7 @@ func storeIconCache(key string, result *FetchResult) {
entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
entry.LastAccess.Store(time.Now())
iconCache.Store(key, entry)
- logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
+ log.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
}
func (e *cacheEntry) IsExpired() bool {
diff --git a/internal/homepage/list_icons.go b/internal/homepage/list_icons.go
index 734282d..72497a8 100644
--- a/internal/homepage/list_icons.go
+++ b/internal/homepage/list_icons.go
@@ -1,6 +1,7 @@
package homepage
import (
+ "context"
"encoding/json"
"fmt"
"io"
@@ -10,10 +11,10 @@ import (
"time"
"github.com/lithammer/fuzzysearch/fuzzy"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
- "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -46,30 +47,30 @@ type (
func (icon *IconMeta) Filenames(ref string) []string {
filenames := make([]string, 0)
if icon.SVG {
- filenames = append(filenames, fmt.Sprintf("%s.svg", ref))
+ filenames = append(filenames, ref+".svg")
if icon.Light {
- filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref))
+ filenames = append(filenames, ref+"-light.svg")
}
if icon.Dark {
- filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref))
+ filenames = append(filenames, ref+"-dark.svg")
}
}
if icon.PNG {
- filenames = append(filenames, fmt.Sprintf("%s.png", ref))
+ filenames = append(filenames, ref+".png")
if icon.Light {
- filenames = append(filenames, fmt.Sprintf("%s-light.png", ref))
+ filenames = append(filenames, ref+"-light.png")
}
if icon.Dark {
- filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref))
+ filenames = append(filenames, ref+"-dark.png")
}
}
if icon.WebP {
- filenames = append(filenames, fmt.Sprintf("%s.webp", ref))
+ filenames = append(filenames, ref+".webp")
if icon.Light {
- filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref))
+ filenames = append(filenames, ref+"-light.webp")
}
if icon.Dark {
- filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref))
+ filenames = append(filenames, ref+"-dark.webp")
}
}
return filenames
@@ -99,21 +100,21 @@ func InitIconListCache() {
iconsCache.Lock()
defer iconsCache.Unlock()
- err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache)
+ err := serialization.LoadJSONIfExist(common.IconListCachePath, iconsCache)
if err != nil {
- logging.Error().Err(err).Msg("failed to load icons")
+ log.Error().Err(err).Msg("failed to load icons")
} else if len(iconsCache.Icons) > 0 {
- logging.Info().
+ log.Info().
Int("icons", len(iconsCache.Icons)).
Msg("icons loaded")
}
if err = updateIcons(); err != nil {
- logging.Error().Err(err).Msg("failed to update icons")
+ log.Error().Err(err).Msg("failed to update icons")
}
task.OnProgramExit("save_icons_cache", func() {
- utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
+ _ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
})
}
@@ -134,17 +135,17 @@ func ListAvailableIcons() (*Cache, error) {
iconsCache.Lock()
defer iconsCache.Unlock()
- logging.Info().Msg("updating icon data")
+ log.Info().Msg("updating icon data")
if err := updateIcons(); err != nil {
return nil, err
}
- logging.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
+ log.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
iconsCache.LastUpdate = time.Now()
- err := utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
+ err := serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
if err != nil {
- logging.Warn().Err(err).Msg("failed to save icons")
+ log.Warn().Err(err).Msg("failed to save icons")
}
return iconsCache, nil
}
@@ -230,14 +231,17 @@ func updateIcons() error {
var httpGet = httpGetImpl
-func MockHttpGet(body []byte) {
+func MockHTTPGet(body []byte) {
httpGet = func(_ string) ([]byte, error) {
return body, nil
}
}
func httpGetImpl(url string) ([]byte, error) {
- req, err := http.NewRequest(http.MethodGet, url, nil)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
@@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error {
}
data := make([]SelfhStIcon, 0)
- err = json.Unmarshal(body, &data)
+ err = json.Unmarshal(body, &data) //nolint:musttag
if err != nil {
return err
}
diff --git a/internal/homepage/list_icons_test.go b/internal/homepage/list_icons_test.go
index 296c635..9569233 100644
--- a/internal/homepage/list_icons_test.go
+++ b/internal/homepage/list_icons_test.go
@@ -68,6 +68,8 @@ type testCases struct {
}
func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
+ t.Helper()
+
for _, item := range test {
icon, ok := iconsCache.Icons[item.Key]
if !ok {
@@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
}
func TestListWalkxCodeIcons(t *testing.T) {
- MockHttpGet([]byte(walkxcodeIcons))
+ MockHTTPGet([]byte(walkxcodeIcons))
if err := UpdateWalkxCodeIcons(); err != nil {
t.Fatal(err)
}
@@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) {
}
func TestListSelfhstIcons(t *testing.T) {
- MockHttpGet([]byte(selfhstIcons))
+ MockHTTPGet([]byte(selfhstIcons))
if err := UpdateSelfhstIcons(); err != nil {
t.Fatal(err)
}
diff --git a/internal/homepage/widgets/widgets.go b/internal/homepage/widgets/widgets.go
index e653b69..6d62eb4 100644
--- a/internal/homepage/widgets/widgets.go
+++ b/internal/homepage/widgets/widgets.go
@@ -4,7 +4,7 @@ import (
"context"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
@@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{
var ErrInvalidProvider = gperr.New("invalid provider")
func (cfg *Config) UnmarshalMap(m map[string]any) error {
- cfg.Provider = m["provider"].(string)
+ var ok bool
+ cfg.Provider, ok = m["provider"].(string)
+ if !ok {
+ return ErrInvalidProvider.Withf("non string")
+ }
if _, ok := widgetProviders[cfg.Provider]; !ok {
return ErrInvalidProvider.Subject(cfg.Provider)
}
delete(m, "provider")
- m, ok := m["config"].(map[string]any)
+ m, ok = m["config"].(map[string]any)
if !ok {
return gperr.New("invalid config")
}
- if err := utils.MapUnmarshalValidate(m, &cfg.Config); err != nil {
- return err
- }
- return nil
+ return serialization.MapUnmarshalValidate(m, &cfg.Config)
}
diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go
index 02f84e1..4cf1139 100644
--- a/internal/idlewatcher/watcher.go
+++ b/internal/idlewatcher/watcher.go
@@ -7,10 +7,10 @@ import (
"time"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
@@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{
}
var (
- causeReload = gperr.New("reloaded")
- causeContainerDestroy = gperr.New("container destroyed")
+ causeReload = gperr.New("reloaded") //nolint:errname
+ causeContainerDestroy = gperr.New("container destroyed") //nolint:errname
)
const reqTimeout = 3 * time.Second
-// TODO: fix stream type
+// TODO: fix stream type.
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
cfg := r.IdlewatcherConfig()
key := cfg.Key()
@@ -120,7 +120,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
return nil, err
}
w.provider = p
- w.l = logging.With().
+ w.l = log.With().
Str("provider", providerType).
Str("container", cfg.ContainerName()).
Logger()
diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go
index 54ff2ed..8fb0df2 100644
--- a/internal/jsonstore/jsonstore.go
+++ b/internal/jsonstore/jsonstore.go
@@ -2,18 +2,17 @@ package jsonstore
import (
"encoding/json"
+ "maps"
"os"
"path/filepath"
"reflect"
- "maps"
-
"github.com/puzpuzpuz/xsync/v4"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
- "github.com/yusing/go-proxy/internal/utils"
)
type namespace string
@@ -36,13 +35,15 @@ type store interface {
json.Unmarshaler
}
-var stores = make(map[namespace]store)
-var storesPath = common.DataDir
+var (
+ stores = make(map[namespace]store)
+ storesPath = common.DataDir
+)
func init() {
task.OnProgramExit("save_stores", func() {
if err := save(); err != nil {
- logging.Error().Err(err).Msg("failed to save stores")
+ log.Error().Err(err).Msg("failed to save stores")
}
})
}
@@ -54,20 +55,20 @@ func loadNS[T store](ns namespace) T {
file, err := os.Open(path)
if err != nil {
if !os.IsNotExist(err) {
- logging.Err(err).
+ log.Err(err).
Str("path", path).
Msg("failed to load store")
}
} else {
defer file.Close()
if err := json.NewDecoder(file).Decode(&store); err != nil {
- logging.Err(err).
+ log.Err(err).
Str("path", path).
Msg("failed to load store")
}
}
stores[ns] = store
- logging.Debug().
+ log.Debug().
Str("namespace", string(ns)).
Str("path", path).
Msg("loaded store")
@@ -77,7 +78,7 @@ func loadNS[T store](ns namespace) T {
func save() error {
errs := gperr.NewBuilder("failed to save data stores")
for ns, store := range stores {
- if err := utils.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
+ if err := serialization.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
errs.Add(err)
}
}
@@ -86,7 +87,7 @@ func save() error {
func Store[VT any](namespace namespace) MapStore[VT] {
if _, ok := stores[namespace]; ok {
- logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
+ log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
}
store := loadNS[*MapStore[VT]](namespace)
stores[namespace] = store
@@ -95,7 +96,7 @@ func Store[VT any](namespace namespace) MapStore[VT] {
func Object[Ptr Initializer](namespace namespace) Ptr {
if _, ok := stores[namespace]; ok {
- logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
+ log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
}
obj := loadNS[*ObjectStore[Ptr]](namespace)
stores[namespace] = obj
@@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error {
}
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
for k, v := range tmp {
- s.Map.Store(k, v)
+ s.Store(k, v)
}
return nil
}
diff --git a/internal/logging/accesslog/access_logger.go b/internal/logging/accesslog/access_logger.go
index 8d6a83e..647181f 100644
--- a/internal/logging/accesslog/access_logger.go
+++ b/internal/logging/accesslog/access_logger.go
@@ -8,8 +8,8 @@ import (
"time"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
@@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
if err != nil {
return nil, err
}
+ if io == nil {
+ return nil, nil //nolint:nilnil
+ }
return NewAccessLoggerWithIO(parent, io, cfg), nil
}
@@ -120,7 +123,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
bufSize: MinBufferSize,
lineBufPool: synk.NewBytesPool(),
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
- logger: logging.With().Str("file", writer.Name()).Logger(),
+ logger: log.With().Str("file", writer.Name()).Logger(),
}
l.supportRotate = unwrap[supportRotate](writer)
@@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line)
- line = l.ACLFormatter.AppendACLLog(line, info, blocked)
+ line = l.AppendACLLog(line, info, blocked)
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
@@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool {
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
if !l.ShouldRotate() {
- return nil, nil
+ return nil, nil //nolint:nilnil
}
l.writer.Flush()
diff --git a/internal/logging/accesslog/config.go b/internal/logging/accesslog/config.go
index e9b90d5..9ac0c5d 100644
--- a/internal/logging/accesslog/config.go
+++ b/internal/logging/accesslog/config.go
@@ -4,7 +4,7 @@ import (
"time"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
@@ -126,6 +126,6 @@ func DefaultACLLoggerConfig() *ACLLoggerConfig {
}
func init() {
- utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
- utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
+ serialization.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
+ serialization.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
}
diff --git a/internal/logging/accesslog/config_test.go b/internal/logging/accesslog/config_test.go
index a44199d..37ccc54 100644
--- a/internal/logging/accesslog/config_test.go
+++ b/internal/logging/accesslog/config_test.go
@@ -5,7 +5,7 @@ import (
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/logging/accesslog"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
expect.NoError(t, err)
var config RequestLoggerConfig
- err = utils.MapUnmarshalValidate(parsed, &config)
+ err = serialization.MapUnmarshalValidate(parsed, &config)
expect.NoError(t, err)
expect.Equal(t, config.Format, FormatCombined)
diff --git a/internal/logging/accesslog/file_logger.go b/internal/logging/accesslog/file_logger.go
index 7a85662..36e791c 100644
--- a/internal/logging/accesslog/file_logger.go
+++ b/internal/logging/accesslog/file_logger.go
@@ -7,7 +7,7 @@ import (
"path/filepath"
"sync"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils"
)
@@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) {
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
return opened, nil
- } else {
- // cannot open as O_APPEND as we need Seek and WriteAt
- f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
- if err != nil {
- return nil, fmt.Errorf("access log open error: %w", err)
- }
- if _, err := f.Seek(0, io.SeekEnd); err != nil {
- return nil, fmt.Errorf("access log seek error: %w", err)
- }
- file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
- openedFiles[path] = file
- go file.closeOnZero()
}
+ // cannot open as O_APPEND as we need Seek and WriteAt
+ f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
+ if err != nil {
+ return nil, fmt.Errorf("access log open error: %w", err)
+ }
+ if _, err := f.Seek(0, io.SeekEnd); err != nil {
+ return nil, fmt.Errorf("access log seek error: %w", err)
+ }
+ file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
+ openedFiles[path] = file
+ go file.closeOnZero()
return file, nil
}
@@ -90,7 +89,7 @@ func (f *File) Close() error {
}
func (f *File) closeOnZero() {
- defer logging.Debug().
+ defer log.Debug().
Str("path", f.path).
Msg("access log closed")
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
index 375d4dd..eae974e 100644
--- a/internal/logging/logging.go
+++ b/internal/logging/logging.go
@@ -1,4 +1,3 @@
-//nolint:zerologlint
package logging
import (
@@ -10,6 +9,8 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/utils/strutils"
+
+ zerologlog "github.com/rs/zerolog/log"
)
var (
@@ -61,22 +62,6 @@ func InitLogger(out ...io.Writer) {
log.SetOutput(writer)
log.SetPrefix("")
log.SetFlags(0)
+ zerolog.TimeFieldFormat = timeFmt
+ zerologlog.Logger = logger
}
-
-func DiscardLogger() { zerolog.SetGlobalLevel(zerolog.Disabled) }
-
-func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
-
-func GetLogger() *zerolog.Logger { return &logger }
-func With() zerolog.Context { return logger.With() }
-
-func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
-
-func Info() *zerolog.Event { return logger.Info() }
-func Warn() *zerolog.Event { return logger.Warn() }
-func Error() *zerolog.Event { return logger.Error() }
-func Err(err error) *zerolog.Event { return logger.Err(err) }
-func Debug() *zerolog.Event { return logger.Debug() }
-func Fatal() *zerolog.Event { return logger.Fatal() }
-func Panic() *zerolog.Event { return logger.Panic() }
-func Trace() *zerolog.Event { return logger.Trace() }
diff --git a/internal/maxmind/types/config.go b/internal/maxmind/types/config.go
index 298f3a5..a2bb526 100644
--- a/internal/maxmind/types/config.go
+++ b/internal/maxmind/types/config.go
@@ -2,8 +2,8 @@ package maxmind
import (
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
)
type (
@@ -28,6 +28,6 @@ func (cfg *Config) Validate() gperr.Error {
}
func (cfg *Config) Logger() *zerolog.Logger {
- l := logging.With().Str("database", string(cfg.Database)).Logger()
+ l := log.With().Str("database", string(cfg.Database)).Logger()
return &l
}
diff --git a/internal/metrics/period/poller.go b/internal/metrics/period/poller.go
index 9422f7a..b068c8f 100644
--- a/internal/metrics/period/poller.go
+++ b/internal/metrics/period/poller.go
@@ -10,8 +10,8 @@ import (
"sync"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/atomic"
)
@@ -47,7 +47,7 @@ var initDataDirOnce sync.Once
func initDataDir() {
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil {
- logging.Error().Err(err).Msg("failed to create metrics data directory")
+ log.Error().Err(err).Msg("failed to create metrics data directory")
}
}
@@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
}
func (p *Poller[T, AggregateT]) savePath() string {
- return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name))
+ return filepath.Join(saveBaseDir, p.name+".json")
}
func (p *Poller[T, AggregateT]) load() error {
@@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
func (p *Poller[T, AggregateT]) Start() {
t := task.RootTask("poller." + p.name)
+ l := log.With().Str("name", p.name).Logger()
err := p.load()
if err != nil {
if !os.IsNotExist(err) {
- logging.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name)
+ l.Err(err).Msg("failed to load last metrics data")
}
} else {
- logging.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total())
+ l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data")
}
go func() {
@@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() {
gatherErrsTicker.Stop()
saveTicker.Stop()
- p.save()
+ if err := p.save(); err != nil {
+ l.Err(err).Msg("failed to save metrics data")
+ }
t.Finish(nil)
}()
- logging.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval)
+ l.Debug().Dur("interval", pollInterval).Msg("Starting poller")
p.pollWithTimeout(t.Context())
@@ -176,7 +179,7 @@ func (p *Poller[T, AggregateT]) Start() {
case <-gatherErrsTicker.C:
errs, ok := p.gatherErrs()
if ok {
- logging.Error().Msg(errs)
+ log.Error().Msg(errs)
}
p.clearErrs()
}
diff --git a/internal/metrics/systeminfo/system_info.go b/internal/metrics/systeminfo/system_info.go
index c76d165..05e317e 100644
--- a/internal/metrics/systeminfo/system_info.go
+++ b/internal/metrics/systeminfo/system_info.go
@@ -8,6 +8,7 @@ import (
"syscall"
"time"
+ "github.com/rs/zerolog/log"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
@@ -16,7 +17,6 @@ import (
"github.com/shirou/gopsutil/v4/warning"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics/period"
)
@@ -130,7 +130,7 @@ func getSystemInfo(ctx context.Context, lastResult *SystemInfo) (*SystemInfo, er
}
})
if allWarnings.HasError() {
- logging.Warn().Msg(allWarnings.String())
+ log.Warn().Msg(allWarnings.String())
}
if allErrors.HasError() {
return nil, allErrors.Error()
@@ -195,7 +195,7 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf
if len(s.Disks) == 0 {
return errs.Error()
}
- logging.Warn().Msg(errs.String())
+ log.Warn().Msg(errs.String())
}
return nil
}
diff --git a/internal/net/gphttp/body.go b/internal/net/gphttp/body.go
index b2d0173..be0bf51 100644
--- a/internal/net/gphttp/body.go
+++ b/internal/net/gphttp/body.go
@@ -6,7 +6,7 @@ import (
"errors"
"net/http"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
)
func WriteBody(w http.ResponseWriter, body []byte) {
@@ -14,9 +14,9 @@ func WriteBody(w http.ResponseWriter, body []byte) {
switch {
case errors.Is(err, http.ErrHandlerTimeout),
errors.Is(err, context.DeadlineExceeded):
- logging.Err(err).Msg("timeout writing body")
+ log.Err(err).Msg("timeout writing body")
default:
- logging.Err(err).Msg("failed to write body")
+ log.Err(err).Msg("failed to write body")
}
}
}
diff --git a/internal/net/gphttp/gpwebsocket/utils.go b/internal/net/gphttp/gpwebsocket/utils.go
index 0a0d229..c5bd3cc 100644
--- a/internal/net/gphttp/gpwebsocket/utils.go
+++ b/internal/net/gphttp/gpwebsocket/utils.go
@@ -9,11 +9,11 @@ import (
"time"
"github.com/gorilla/websocket"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
)
func warnNoMatchDomains() {
- logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
+ log.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
}
var warnNoMatchDomainOnce sync.Once
diff --git a/internal/net/gphttp/loadbalancer/loadbalancer.go b/internal/net/gphttp/loadbalancer/loadbalancer.go
index d36bf98..90b795c 100644
--- a/internal/net/gphttp/loadbalancer/loadbalancer.go
+++ b/internal/net/gphttp/loadbalancer/loadbalancer.go
@@ -7,8 +7,8 @@ import (
"time"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/task"
@@ -47,7 +47,7 @@ func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Config: new(Config),
pool: pool.New[Server]("loadbalancer." + cfg.Link),
- l: logging.With().Str("name", cfg.Link).Logger(),
+ l: log.With().Str("name", cfg.Link).Logger(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb
diff --git a/internal/net/gphttp/loadbalancer/types/weight.go b/internal/net/gphttp/loadbalancer/types/weight.go
index 2bf7d84..2339a27 100644
--- a/internal/net/gphttp/loadbalancer/types/weight.go
+++ b/internal/net/gphttp/loadbalancer/types/weight.go
@@ -1,3 +1,3 @@
package types
-type Weight uint16
+type Weight int
diff --git a/internal/net/gphttp/logging.go b/internal/net/gphttp/logging.go
index cfb67f0..9db2f35 100644
--- a/internal/net/gphttp/logging.go
+++ b/internal/net/gphttp/logging.go
@@ -4,14 +4,14 @@ import (
"net/http"
"github.com/rs/zerolog"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
- return logging.WithLevel(level).
- Str("remote", r.RemoteAddr).
- Str("host", r.Host).
- Str("uri", r.Method+" "+r.RequestURI)
+ return log.WithLevel(level). //nolint:zerologlint
+ Str("remote", r.RemoteAddr).
+ Str("host", r.Host).
+ Str("uri", r.Method+" "+r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
diff --git a/internal/net/gphttp/middleware/captcha/middleware.go b/internal/net/gphttp/middleware/captcha/middleware.go
index 91da59c..c4769ab 100644
--- a/internal/net/gphttp/middleware/captcha/middleware.go
+++ b/internal/net/gphttp/middleware/captcha/middleware.go
@@ -4,8 +4,8 @@ import (
"net/http"
"text/template"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/auth"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
_ "embed"
@@ -55,7 +55,7 @@ func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed boo
"FormHTML": p.FormHTML(),
})
if err != nil {
- logging.Error().Err(err).Msg("failed to execute captcha page")
+ log.Error().Err(err).Msg("failed to execute captcha page")
}
return false
}
diff --git a/internal/net/gphttp/middleware/cidr_whitelist.go b/internal/net/gphttp/middleware/cidr_whitelist.go
index 6b9271f..291b45f 100644
--- a/internal/net/gphttp/middleware/cidr_whitelist.go
+++ b/internal/net/gphttp/middleware/cidr_whitelist.go
@@ -7,7 +7,7 @@ import (
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -34,7 +34,7 @@ var (
)
func init() {
- utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
+ serialization.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
statusCode := fl.Field().Int()
return gphttp.IsStatusCodeValid(int(statusCode))
})
@@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
- for _, cidr := range wl.CIDRWhitelistOpts.Allow {
+ for _, cidr := range wl.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
@@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
- wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
+ wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
}
}
if !allow {
diff --git a/internal/net/gphttp/middleware/cidr_whitelist_test.go b/internal/net/gphttp/middleware/cidr_whitelist_test.go
index a8c7cee..0d73f29 100644
--- a/internal/net/gphttp/middleware/cidr_whitelist_test.go
+++ b/internal/net/gphttp/middleware/cidr_whitelist_test.go
@@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
- ExpectError(t, utils.ErrValidationError, err)
+ ExpectError(t, serialization.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
@@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600,
"message": testMessage,
})
- ExpectError(t, utils.ErrValidationError, err)
+ ExpectError(t, serialization.ErrValidationError, err)
})
}
diff --git a/internal/net/gphttp/middleware/cloudflare_real_ip.go b/internal/net/gphttp/middleware/cloudflare_real_ip.go
index 19314d6..59c79d3 100644
--- a/internal/net/gphttp/middleware/cloudflare_real_ip.go
+++ b/internal/net/gphttp/middleware/cloudflare_real_ip.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"fmt"
"io"
@@ -9,8 +10,8 @@ import (
"sync"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
@@ -89,21 +90,29 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
)
if err != nil {
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
- logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
+ log.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
if len(cfCIDRs) == 0 {
- logging.Warn().Msg("cloudflare CIDR range is empty")
+ log.Warn().Msg("cloudflare CIDR range is empty")
}
}
cfCIDRsLastUpdate.Store(time.Now())
- logging.Info().Msg("cloudflare CIDR range updated")
+ log.Info().Msg("cloudflare CIDR range updated")
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
- resp, err := http.Get(endpoint)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
+ if err != nil {
+ return err
+ }
+
+ resp, err := http.DefaultClient.Do(req) //nolint:gosec
if err != nil {
return err
}
diff --git a/internal/net/gphttp/middleware/custom_error_page.go b/internal/net/gphttp/middleware/custom_error_page.go
index 76110be..baef63f 100644
--- a/internal/net/gphttp/middleware/custom_error_page.go
+++ b/internal/net/gphttp/middleware/custom_error_page.go
@@ -8,7 +8,7 @@ import (
"strconv"
"strings"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
@@ -32,7 +32,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
- logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
+ log.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
@@ -40,7 +40,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
- logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
+ log.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
@@ -56,7 +56,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
- logging.Error().Msg("unable to load resource " + filename)
+ log.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
@@ -68,10 +68,10 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
case ".css":
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
- logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
+ log.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
- logging.Err(err).Msg("unable to write resource " + filename)
+ log.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true
diff --git a/internal/net/gphttp/middleware/errorpage/error_page.go b/internal/net/gphttp/middleware/errorpage/error_page.go
index a1acec6..e63b65a 100644
--- a/internal/net/gphttp/middleware/errorpage/error_page.go
+++ b/internal/net/gphttp/middleware/errorpage/error_page.go
@@ -6,9 +6,9 @@ import (
"path"
"sync"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@@ -48,7 +48,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
- logging.Err(err).Msg("failed to list error page resources")
+ log.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
@@ -57,11 +57,11 @@ func loadContent() {
}
content, err := os.ReadFile(file)
if err != nil {
- logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
+ log.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
- logging.Info().Msgf("error page resource %s loaded", file)
+ log.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
@@ -83,9 +83,9 @@ func watchDir() {
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
- logging.Warn().Msgf("error page resource %s deleted", filename)
+ log.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
- logging.Warn().Msgf("error page resource %s deleted", filename)
+ log.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}
diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go
index 213aa10..ab7227f 100644
--- a/internal/net/gphttp/middleware/middleware.go
+++ b/internal/net/gphttp/middleware/middleware.go
@@ -8,11 +8,11 @@ import (
"sort"
"strings"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
@@ -87,7 +87,7 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
- logging.Trace().Msgf("middleware %s enabled trace", m.name)
+ log.Trace().Msgf("middleware %s enabled trace", m.name)
}
}
@@ -118,14 +118,14 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
"priority": optsRaw["priority"],
"bypass": optsRaw["bypass"],
}
- if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
+ if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
return err
}
optsRaw = maps.Clone(optsRaw)
for k := range commonOpts {
delete(optsRaw, k)
}
- return utils.MapUnmarshalValidate(optsRaw, m.impl)
+ return serialization.MapUnmarshalValidate(optsRaw, m.impl)
}
func (m *Middleware) finalize() error {
diff --git a/internal/net/gphttp/middleware/middlewares.go b/internal/net/gphttp/middleware/middlewares.go
index ca2a85f..8ae4ee9 100644
--- a/internal/net/gphttp/middleware/middlewares.go
+++ b/internal/net/gphttp/middleware/middlewares.go
@@ -3,9 +3,9 @@ package middleware
import (
"path"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -59,7 +59,7 @@ func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
- logging.Err(err).Msg("failed to list middleware definitions")
+ log.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
@@ -75,7 +75,7 @@ func LoadComposeFiles() {
continue
}
allMiddlewares[name] = m
- logging.Info().
+ log.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
@@ -94,7 +94,7 @@ func LoadComposeFiles() {
continue
}
allMiddlewares[name] = m
- logging.Info().
+ log.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
diff --git a/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go b/internal/net/gphttp/reverseproxy/reverse_proxy.go
similarity index 96%
rename from internal/net/gphttp/reverseproxy/reverse_proxy_mod.go
rename to internal/net/gphttp/reverseproxy/reverse_proxy.go
index a28d871..59faafe 100644
--- a/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go
+++ b/internal/net/gphttp/reverseproxy/reverse_proxy.go
@@ -25,7 +25,7 @@ import (
"sync"
"github.com/rs/zerolog"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
@@ -138,7 +138,7 @@ func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper
panic("nil transport")
}
rp := &ReverseProxy{
- Logger: logging.With().Str("name", name).Logger(),
+ Logger: log.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
@@ -173,17 +173,17 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF),
errors.Is(err, context.DeadlineExceeded):
- logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
+ log.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
default:
var recordErr tls.RecordHeaderError
if errors.As(err, &recordErr) {
- logging.Error().
+ log.Error().
Str("url", reqURL).
Msgf(`scheme was likely misconfigured as https,
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
- logging.Err(err).Msg("underlying error")
+ log.Err(err).Msg("underlying error")
} else {
- logging.Err(err).Str("url", reqURL).Msg("http proxy error")
+ log.Err(err).Str("url", reqURL).Msg("http proxy error")
}
}
@@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
- /* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
@@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return nil
},
}
- outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
+ outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
res, err := transport.RoundTrip(outreq)
@@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
- /* trunk-ignore(golangci-lint/errorlint) */
+ //nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
- /* trunk-ignore(golangci-lint/errorlint) */
+ //nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
- /* trunk-ignore(golangci-lint/errcheck) */
+ //nolint:errcheck
bdp.Start()
}
diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go
index c42afec..bf2e7f6 100644
--- a/internal/net/gphttp/server/server.go
+++ b/internal/net/gphttp/server/server.go
@@ -9,14 +9,14 @@ import (
"github.com/quic-go/quic-go/http3"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/acl"
"github.com/yusing/go-proxy/internal/common"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type CertProvider interface {
- GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
+ GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type Server struct {
@@ -53,7 +53,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) {
func NewServer(opt Options) (s *Server) {
var httpSer, httpsSer *http.Server
- logger := logging.With().Str("server", opt.Name).Logger()
+ logger := log.With().Str("server", opt.Name).Logger()
certAvailable := false
if opt.CertProvider != nil {
diff --git a/internal/notif/config.go b/internal/notif/config.go
index 8bc396e..7f91c3d 100644
--- a/internal/notif/config.go
+++ b/internal/notif/config.go
@@ -2,7 +2,7 @@ package notif
import (
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type NotificationConfig struct {
@@ -46,5 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error)
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
}
- return utils.MapUnmarshalValidate(m, cfg.Provider)
+ return serialization.MapUnmarshalValidate(m, cfg.Provider)
}
diff --git a/internal/notif/config_test.go b/internal/notif/config_test.go
index a7afaac..68d3897 100644
--- a/internal/notif/config_test.go
+++ b/internal/notif/config_test.go
@@ -4,7 +4,7 @@ import (
"net/http"
"testing"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -181,7 +181,7 @@ func TestNotificationConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
var cfg NotificationConfig
provider := tt.cfg["provider"]
- err := utils.MapUnmarshalValidate(tt.cfg, &cfg)
+ err := serialization.MapUnmarshalValidate(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
} else {
diff --git a/internal/notif/providers.go b/internal/notif/providers.go
index 84c3c41..be7ba9f 100644
--- a/internal/notif/providers.go
+++ b/internal/notif/providers.go
@@ -8,14 +8,14 @@ import (
"net/http"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
)
type (
Provider interface {
- utils.CustomValidator
+ serialization.CustomValidator
GetName() string
GetURL() string
@@ -73,7 +73,7 @@ func (msg *LogMessage) notify(ctx context.Context, provider Provider) error {
switch resp.StatusCode {
case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent:
body, _ := io.ReadAll(resp.Body)
- logging.Debug().
+ log.Debug().
Str("provider", provider.GetName()).
Str("url", provider.GetURL()).
Str("status", resp.Status).
diff --git a/internal/route/provider/docker.go b/internal/route/provider/docker.go
index 36c1811..340b628 100755
--- a/internal/route/provider/docker.go
+++ b/internal/route/provider/docker.go
@@ -7,12 +7,12 @@ import (
"github.com/docker/docker/client"
"github.com/goccy/go-yaml"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/route"
- U "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher"
)
@@ -36,7 +36,7 @@ func DockerProviderImpl(name, dockerHost string) ProviderImpl {
return &DockerProvider{
name,
dockerHost,
- logging.With().Str("type", "docker").Str("name", name).Logger(),
+ log.With().Str("type", "docker").Str("name", name).Logger(),
}
}
@@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
// Always non-nil.
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
if !container.IsExplicit && p.IsExplicitOnly() {
- return nil, nil
+ return make(route.Routes, 0), nil
}
routes := make(route.Routes, len(container.Aliases))
@@ -180,7 +180,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
}
// deserialize map into entry object
- err := U.MapUnmarshalValidate(entryMap, r)
+ err := serialization.MapUnmarshalValidate(entryMap, r)
if err != nil {
errs.Add(err.Subject(alias))
} else {
@@ -189,7 +189,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
}
if wildcardProps != nil {
for _, re := range routes {
- if err := U.MapUnmarshalValidate(wildcardProps, re); err != nil {
+ if err := serialization.MapUnmarshalValidate(wildcardProps, re); err != nil {
errs.Add(err.Subject(docker.WildcardAlias))
break
}
diff --git a/internal/route/provider/file.go b/internal/route/provider/file.go
index 6c23a60..06921b7 100644
--- a/internal/route/provider/file.go
+++ b/internal/route/provider/file.go
@@ -6,11 +6,11 @@ import (
"strings"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/route"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
W "github.com/yusing/go-proxy/internal/watcher"
)
@@ -24,7 +24,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
- l: logging.With().Str("type", "file").Str("name", filename).Logger(),
+ l: log.With().Str("type", "file").Str("name", filename).Logger(),
}
_, err := os.Stat(impl.path)
if err != nil {
@@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
}
func validate(data []byte) (routes route.Routes, err gperr.Error) {
- err = utils.UnmarshalValidateYAML(data, &routes)
+ err = serialization.UnmarshalValidateYAML(data, &routes)
return
}
diff --git a/internal/route/route.go b/internal/route/route.go
index 9ff3eb6..62c41f5 100644
--- a/internal/route/route.go
+++ b/internal/route/route.go
@@ -7,12 +7,12 @@ import (
"time"
"github.com/docker/docker/api/types/container"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/homepage"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
- "github.com/yusing/go-proxy/internal/logging"
netutils "github.com/yusing/go-proxy/internal/net"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxmox"
@@ -116,7 +116,7 @@ func (r *Route) Validate() gperr.Error {
Subject(containerName)
}
- l := logging.With().Str("container", containerName).Logger()
+ l := log.With().Str("container", containerName).Logger()
l.Info().Msg("checking if container is running")
running, err := node.LXCIsRunning(ctx, vmid)
diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go
index aa8e8d1..0601e3a 100644
--- a/internal/route/rules/rules_test.go
+++ b/internal/route/rules/rules_test.go
@@ -3,7 +3,7 @@ package rules
import (
"testing"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) {
var rules struct {
Rules Rules
}
- err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules)
+ err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
ExpectNoError(t, err)
ExpectEqual(t, len(rules.Rules), len(test))
ExpectEqual(t, rules.Rules[0].Name, "test")
diff --git a/internal/route/stream.go b/internal/route/stream.go
index bd87cfe..9c76de0 100755
--- a/internal/route/stream.go
+++ b/internal/route/stream.go
@@ -5,9 +5,9 @@ import (
"errors"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/idlewatcher"
- "github.com/yusing/go-proxy/internal/logging"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
@@ -32,7 +32,7 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
// TODO: support non-coherent scheme
return &StreamRoute{
Route: base,
- l: logging.With().
+ l: log.With().
Str("type", string(base.Scheme)).
Str("name", base.Name()).
Logger(),
diff --git a/internal/route/types/http_config_test.go b/internal/route/types/http_config_test.go
index c6ea8cf..8f41760 100644
--- a/internal/route/types/http_config_test.go
+++ b/internal/route/types/http_config_test.go
@@ -6,7 +6,7 @@ import (
. "github.com/yusing/go-proxy/internal/route"
route "github.com/yusing/go-proxy/internal/route/types"
- "github.com/yusing/go-proxy/internal/utils"
+ "github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
@@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
cfg := Route{}
tt.input["host"] = "internal"
- err := utils.MapUnmarshalValidate(tt.input, &cfg)
+ err := serialization.MapUnmarshalValidate(tt.input, &cfg)
if err != nil {
expect.NoError(t, err)
}
diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go
index 62149af..581e273 100644
--- a/internal/route/udp_forwarder.go
+++ b/internal/route/udp_forwarder.go
@@ -6,8 +6,8 @@ import (
"net"
"sync"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -88,7 +88,7 @@ func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
if err == nil {
- logging.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
+ log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
}
return
}
@@ -102,7 +102,7 @@ func (conn *UDPConn) read() (err error) {
conn.buf.oobn = 0
}
if err == nil {
- logging.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
+ log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
return
}
@@ -110,7 +110,7 @@ func (conn *UDPConn) read() (err error) {
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
if err == nil {
- logging.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
+ log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
}
return
}
@@ -120,12 +120,12 @@ func (conn *UDPConn) write() (err error) {
case *net.UDPConn:
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
if err == nil {
- logging.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
+ log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
default:
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
if err == nil {
- logging.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
+ log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
}
}
diff --git a/internal/utils/serialization.go b/internal/serialization/serialization.go
similarity index 97%
rename from internal/utils/serialization.go
rename to internal/serialization/serialization.go
index e540d23..54b9631 100644
--- a/internal/utils/serialization.go
+++ b/internal/serialization/serialization.go
@@ -1,4 +1,4 @@
-package utils
+package serialization
import (
"encoding/json"
@@ -12,7 +12,9 @@ import (
"github.com/go-playground/validator/v10"
"github.com/goccy/go-yaml"
+ "github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/go-proxy/internal/gperr"
+ "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -40,14 +42,14 @@ var (
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
-var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
+var defaultValues = xsync.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
t := reflect.TypeFor[T]()
if t.Kind() == reflect.Ptr {
panic("pointer of pointer")
}
- if defaultValues.Has(t) {
+ if _, ok := defaultValues.Load(t); ok {
panic("default value for " + t.String() + " already registered")
}
defaultValues.Store(t, func() any { return factory() })
@@ -259,7 +261,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
errs.Add(err.Subject(k))
}
} else {
- errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
+ errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping))))
}
}
if hasValidateTag && checkValidateTag {
diff --git a/internal/utils/serialization_test.go b/internal/serialization/serialization_test.go
similarity index 99%
rename from internal/utils/serialization_test.go
rename to internal/serialization/serialization_test.go
index 6a08779..d067f81 100644
--- a/internal/utils/serialization_test.go
+++ b/internal/serialization/serialization_test.go
@@ -1,4 +1,4 @@
-package utils
+package serialization
import (
"reflect"
diff --git a/internal/utils/validation.go b/internal/serialization/validation.go
similarity index 95%
rename from internal/utils/validation.go
rename to internal/serialization/validation.go
index 490c348..0b222f2 100644
--- a/internal/utils/validation.go
+++ b/internal/serialization/validation.go
@@ -1,4 +1,4 @@
-package utils
+package serialization
import (
"github.com/go-playground/validator/v10"
diff --git a/internal/task/task.go b/internal/task/task.go
index 2421a13..9faa7a6 100644
--- a/internal/task/task.go
+++ b/internal/task/task.go
@@ -7,9 +7,9 @@ import (
"sync/atomic"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@@ -116,14 +116,14 @@ func (t *Task) Finish(reason any) {
func (t *Task) finish(reason any) {
t.cancel(fmtCause(reason))
if !waitWithTimeout(t.childrenDone) {
- logging.Debug().
+ log.Debug().
Str("task", t.name).
Strs("subtasks", t.listChildren()).
Msg("Timeout waiting for subtasks to finish")
}
go t.runCallbacks()
if !waitWithTimeout(t.callbacksDone) {
- logging.Debug().
+ log.Debug().
Str("task", t.name).
Strs("callbacks", t.listCallbacks()).
Msg("Timeout waiting for callbacks to finish")
@@ -134,7 +134,7 @@ func (t *Task) finish(reason any) {
}
t.parent.subChildCount()
allTasks.Remove(t)
- logging.Trace().Msg("task " + t.name + " finished")
+ log.Trace().Msg("task " + t.name + " finished")
}
// Subtask returns a new subtask with the given name, derived from the parent's context.
@@ -166,7 +166,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
}()
}
- logging.Trace().Msg("task " + child.name + " started")
+ log.Trace().Msg("task " + child.name + " started")
return child
}
@@ -189,7 +189,7 @@ func (t *Task) MarshalText() ([]byte, error) {
func (t *Task) invokeWithRecover(fn func(), caller string) {
defer func() {
if err := recover(); err != nil {
- logging.Error().
+ log.Error().
Interface("err", err).
Msg("panic in task " + t.name + "." + caller)
if common.IsDebug {
diff --git a/internal/task/utils.go b/internal/task/utils.go
index da0bc36..e75b3be 100644
--- a/internal/task/utils.go
+++ b/internal/task/utils.go
@@ -9,7 +9,7 @@ import (
"syscall"
"time"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@@ -68,10 +68,10 @@ func GracefulShutdown(timeout time.Duration) (err error) {
case <-after:
b, err := json.Marshal(DebugTaskList())
if err != nil {
- logging.Warn().Err(err).Msg("failed to marshal tasks")
+ log.Warn().Err(err).Msg("failed to marshal tasks")
return context.DeadlineExceeded
}
- logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
+ log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
return context.DeadlineExceeded
}
}
@@ -87,6 +87,6 @@ func WaitExit(shutdownTimeout int) {
<-sig
// gracefully shutdown
- logging.Info().Msg("shutting down")
+ log.Info().Msg("shutting down")
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
}
diff --git a/internal/utils/go.mod b/internal/utils/go.mod
new file mode 100644
index 0000000..6ccf74c
--- /dev/null
+++ b/internal/utils/go.mod
@@ -0,0 +1,21 @@
+module github.com/yusing/go-proxy/internal/utils
+
+go 1.24.3
+
+require (
+ github.com/goccy/go-yaml v1.17.1
+ github.com/puzpuzpuz/xsync/v4 v4.1.0
+ github.com/rs/zerolog v1.34.0
+ github.com/stretchr/testify v1.10.0
+ go.uber.org/atomic v1.11.0
+ golang.org/x/text v0.25.0
+)
+
+require (
+ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/internal/utils/go.sum b/internal/utils/go.sum
new file mode 100644
index 0000000..7477df6
--- /dev/null
+++ b/internal/utils/go.sum
@@ -0,0 +1,36 @@
+github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
+github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
+github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U=
+github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
+golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/utils/io.go b/internal/utils/io.go
index 5d8488a..33881fc 100644
--- a/internal/utils/io.go
+++ b/internal/utils/io.go
@@ -8,7 +8,6 @@ import (
"sync"
"syscall"
- "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/synk"
)
@@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
}
}
-func (p BidirectionalPipe) Start() gperr.Error {
+func (p BidirectionalPipe) Start() error {
var wg sync.WaitGroup
wg.Add(2)
- b := gperr.NewBuilder("bidirectional pipe error")
+ var srcErr, dstErr error
go func() {
- b.Add(p.pSrcDst.Start())
+ srcErr = p.pSrcDst.Start()
wg.Done()
}()
go func() {
- b.Add(p.pDstSrc.Start())
+ dstErr = p.pDstSrc.Start()
wg.Done()
}()
wg.Wait()
- return b.Error()
+ return errors.Join(srcErr, dstErr)
}
type httpFlusher interface {
@@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
wCloser, wCanClose := dst.Writer.(io.Closer)
rCloser, rCanClose := src.Reader.(io.Closer)
if wCanClose || rCanClose {
- if src.ctx == dst.ctx {
- go func() {
- <-src.ctx.Done()
- if wCanClose {
- wCloser.Close()
- }
- if rCanClose {
- rCloser.Close()
- }
- }()
- } else {
- if wCloser != nil {
- go func() {
- <-src.ctx.Done()
- wCloser.Close()
- }()
+ go func() {
+ select {
+ case <-src.ctx.Done():
+ case <-dst.ctx.Done():
}
- if rCloser != nil {
- go func() {
- <-dst.ctx.Done()
- rCloser.Close()
- }()
+ if rCanClose {
+ defer rCloser.Close()
}
- }
+ if wCanClose {
+ defer wCloser.Close()
+ }
+ }()
}
flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil
diff --git a/internal/utils/pool/pool.go b/internal/utils/pool/pool.go
index 7abe4bc..84056e9 100644
--- a/internal/utils/pool/pool.go
+++ b/internal/utils/pool/pool.go
@@ -4,7 +4,7 @@ import (
"sort"
"github.com/puzpuzpuz/xsync/v4"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
)
type (
@@ -29,12 +29,12 @@ func (p Pool[T]) Name() string {
func (p Pool[T]) Add(obj T) {
p.checkExists(obj.Key())
p.m.Store(obj.Key(), obj)
- logging.Info().Msgf("%s: added %s", p.name, obj.Name())
+ log.Info().Msgf("%s: added %s", p.name, obj.Name())
}
func (p Pool[T]) Del(obj T) {
p.m.Delete(obj.Key())
- logging.Info().Msgf("%s: removed %s", p.name, obj.Name())
+ log.Info().Msgf("%s: removed %s", p.name, obj.Name())
}
func (p Pool[T]) Get(key string) (T, bool) {
diff --git a/internal/utils/pool/pool_debug.go b/internal/utils/pool/pool_debug.go
index 6ab84f6..bb444db 100644
--- a/internal/utils/pool/pool_debug.go
+++ b/internal/utils/pool/pool_debug.go
@@ -5,11 +5,11 @@ package pool
import (
"runtime/debug"
- "github.com/yusing/go-proxy/internal/logging"
+ "github.com/rs/zerolog/log"
)
func (p Pool[T]) checkExists(key string) {
if _, ok := p.m.Load(key); ok {
- logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
+ log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
}
}
diff --git a/internal/utils/synk/pool.go b/internal/utils/synk/pool.go
index 11e282a..15cbd64 100644
--- a/internal/utils/synk/pool.go
+++ b/internal/utils/synk/pool.go
@@ -1,16 +1,35 @@
package synk
import (
- "os"
- "os/signal"
- "sync/atomic"
- "time"
-
- "github.com/yusing/go-proxy/internal/logging"
+ "runtime"
+ "unsafe"
)
+type weakBuf = unsafe.Pointer
+
+func makeWeak(b *[]byte) weakBuf {
+ ptr := runtime_registerWeakPointer(unsafe.Pointer(b))
+ runtime.KeepAlive(ptr)
+ addCleanup(b, addGCed, cap(*b))
+ return weakBuf(ptr)
+}
+
+func getBufFromWeak(w weakBuf) []byte {
+ ptr := (*[]byte)(runtime_makeStrongFromWeak(w))
+ if ptr == nil {
+ return nil
+ }
+ return *ptr
+}
+
+//go:linkname runtime_registerWeakPointer weak.runtime_registerWeakPointer
+func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer
+
+//go:linkname runtime_makeStrongFromWeak weak.runtime_makeStrongFromWeak
+func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer
+
type BytesPool struct {
- pool chan []byte
+ pool chan weakBuf
initSize int
}
@@ -22,19 +41,15 @@ const (
const (
InPoolLimit = 32 * mb
- DefaultInitBytes = 32 * kb
- PoolThreshold = 64 * kb
+ DefaultInitBytes = 4 * kb
+ PoolThreshold = 256 * kb
DropThresholdHigh = 4 * mb
PoolSize = InPoolLimit / PoolThreshold
-
- CleanupInterval = 5 * time.Second
- MaxDropsPerCycle = 10
- MaxChecksPerCycle = 100
)
var bytesPool = &BytesPool{
- pool: make(chan []byte, PoolSize),
+ pool: make(chan weakBuf, PoolSize),
initSize: DefaultInitBytes,
}
@@ -43,12 +58,18 @@ func NewBytesPool() *BytesPool {
}
func (p *BytesPool) Get() []byte {
- select {
- case b := <-p.pool:
- subInPoolSize(int64(cap(b)))
- return b
- default:
- return make([]byte, 0, p.initSize)
+ for {
+ select {
+ case bWeak := <-p.pool:
+ bPtr := getBufFromWeak(bWeak)
+ if bPtr == nil {
+ continue
+ }
+ addReused(cap(bPtr))
+ return bPtr
+ default:
+ return make([]byte, 0, p.initSize)
+ }
}
}
@@ -56,90 +77,43 @@ func (p *BytesPool) GetSized(size int) []byte {
if size <= PoolThreshold {
return make([]byte, size)
}
- select {
- case b := <-p.pool:
- if size <= cap(b) {
- subInPoolSize(int64(cap(b)))
- return b[:size]
- }
+ for {
select {
- case p.pool <- b:
- addInPoolSize(int64(cap(b)))
+ case bWeak := <-p.pool:
+ bPtr := getBufFromWeak(bWeak)
+ if bPtr == nil {
+ continue
+ }
+ capB := cap(bPtr)
+ if capB >= size {
+ addReused(capB)
+ return (bPtr)[:size]
+ }
+ select {
+ case p.pool <- bWeak:
+ default:
+ // just drop it
+ }
default:
}
- default:
+ return make([]byte, size)
}
- return make([]byte, size)
}
func (p *BytesPool) Put(b []byte) {
size := cap(b)
- if size > DropThresholdHigh || poolFull() {
+ if size <= PoolThreshold || size > DropThresholdHigh {
return
}
b = b[:0]
+ w := makeWeak(&b)
select {
- case p.pool <- b:
- addInPoolSize(int64(size))
- return
+ case p.pool <- w:
default:
// just drop it
}
}
-var inPoolSize int64
-
-func addInPoolSize(size int64) {
- atomic.AddInt64(&inPoolSize, size)
-}
-
-func subInPoolSize(size int64) {
- atomic.AddInt64(&inPoolSize, -size)
-}
-
func init() {
- // Periodically drop some buffers to prevent excessive memory usage
- go func() {
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, os.Interrupt)
-
- cleanupTicker := time.NewTicker(CleanupInterval)
- defer cleanupTicker.Stop()
-
- for {
- select {
- case <-cleanupTicker.C:
- dropBuffers()
- case <-sigCh:
- return
- }
- }
- }()
-}
-
-func poolFull() bool {
- return atomic.LoadInt64(&inPoolSize) >= InPoolLimit
-}
-
-// dropBuffers removes excess buffers from the pool when it grows too large.
-func dropBuffers() {
- // Check if pool has more than a threshold of buffers
- count := 0
- droppedSize := 0
- checks := 0
- for count < MaxDropsPerCycle && checks < MaxChecksPerCycle && atomic.LoadInt64(&inPoolSize) > InPoolLimit*2/3 {
- select {
- case b := <-bytesPool.pool:
- n := cap(b)
- subInPoolSize(int64(n))
- droppedSize += n
- count++
- default:
- time.Sleep(10 * time.Millisecond)
- }
- checks++
- }
- if count > 0 {
- logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
- }
+ initPoolStats()
}
diff --git a/internal/utils/synk/pool_bench_test.go b/internal/utils/synk/pool_bench_test.go
index 6bb6e02..7339a23 100644
--- a/internal/utils/synk/pool_bench_test.go
+++ b/internal/utils/synk/pool_bench_test.go
@@ -4,7 +4,7 @@ import (
"testing"
)
-var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024}
+var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024}
func BenchmarkBytesPool_GetSmall(b *testing.B) {
for b.Loop() {
diff --git a/internal/utils/synk/pool_debug.go b/internal/utils/synk/pool_debug.go
new file mode 100644
index 0000000..0a7aa00
--- /dev/null
+++ b/internal/utils/synk/pool_debug.go
@@ -0,0 +1,55 @@
+//go:build !production
+
+package synk
+
+import (
+ "os"
+ "os/signal"
+ "runtime"
+ "sync/atomic"
+ "time"
+
+ "github.com/rs/zerolog/log"
+ "github.com/yusing/go-proxy/internal/utils/strutils"
+)
+
+var (
+ numReused, sizeReused uint64
+ numGCed, sizeGCed uint64
+)
+
+func addReused(size int) {
+ atomic.AddUint64(&numReused, 1)
+ atomic.AddUint64(&sizeReused, uint64(size))
+}
+
+func addGCed(size int) {
+ atomic.AddUint64(&numGCed, 1)
+ atomic.AddUint64(&sizeGCed, uint64(size))
+}
+
+var addCleanup = runtime.AddCleanup[[]byte, int]
+
+func initPoolStats() {
+ go func() {
+ statsTicker := time.NewTicker(5 * time.Second)
+ defer statsTicker.Stop()
+
+ sig := make(chan os.Signal, 1)
+ signal.Notify(sig, os.Interrupt)
+
+ for {
+ select {
+ case <-sig:
+ return
+ case <-statsTicker.C:
+ log.Info().
+ Uint64("numReused", atomic.LoadUint64(&numReused)).
+ Str("sizeReused", strutils.FormatByteSize(atomic.LoadUint64(&sizeReused))).
+ Uint64("numGCed", atomic.LoadUint64(&numGCed)).
+ Str("sizeGCed", strutils.FormatByteSize(atomic.LoadUint64(&sizeGCed))).
+ Msg("bytes pool stats")
+ }
+ }
+ }()
+}
diff --git a/internal/utils/synk/pool_prod.go b/internal/utils/synk/pool_prod.go
new file mode 100644
index 0000000..3cca12d
--- /dev/null
+++ b/internal/utils/synk/pool_prod.go
@@ -0,0 +1,8 @@
+//go:build production
+
+package synk
+
+func addReused(size int) {}
+func addGCed(size int) {}
+func initPoolStats() {}
+func addCleanup(ptr *[]byte, cleanup func(int), arg int) {}
diff --git a/internal/utils/testing/expect.go b/internal/utils/testing/expect.go
index a87ddaa..4455d7c 100644
--- a/internal/utils/testing/expect.go
+++ b/internal/utils/testing/expect.go
@@ -2,14 +2,16 @@ package expect
import (
"os"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
- "github.com/yusing/go-proxy/internal/common"
)
+var isTest = strings.HasSuffix(os.Args[0], ".test")
+
func init() {
- if common.IsTest {
+ if isTest {
// force verbose output
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go
index 52f118c..13411c3 100644
--- a/internal/utils/testing/testing.go
+++ b/internal/utils/testing/testing.go
@@ -1,19 +1,11 @@
package expect
import (
- "os"
"testing"
"github.com/stretchr/testify/require"
- "github.com/yusing/go-proxy/internal/common"
)
-func init() {
- if common.IsTest {
- os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
- }
-}
-
func ExpectNoError(t *testing.T, err error) {
t.Helper()
require.NoError(t, err)
diff --git a/internal/utils/time_now.go b/internal/utils/time_now.go
index 8b5e155..00bcff3 100644
--- a/internal/utils/time_now.go
+++ b/internal/utils/time_now.go
@@ -3,7 +3,6 @@ package utils
import (
"time"
- "github.com/yusing/go-proxy/internal/task"
"go.uber.org/atomic"
)
@@ -38,8 +37,6 @@ func init() {
go func() {
for {
select {
- case <-task.RootContext().Done():
- return
case <-timeNowTicker.C:
shouldCallTimeNow.Store(true)
}
diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go
index 45cca6a..4bfee33 100644
--- a/internal/watcher/directory_watcher.go
+++ b/internal/watcher/directory_watcher.go
@@ -8,8 +8,8 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/events"
)
@@ -41,13 +41,13 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
//! subdirectories are not watched
w, err := fsnotify.NewWatcher()
if err != nil {
- logging.Panic().Err(err).Msg("unable to create fs watcher")
+ log.Panic().Err(err).Msg("unable to create fs watcher")
}
if err = w.Add(dirPath); err != nil {
- logging.Panic().Err(err).Msg("unable to create fs watcher")
+ log.Panic().Err(err).Msg("unable to create fs watcher")
}
helper := &DirWatcher{
- Logger: logging.With().
+ Logger: log.With().
Str("type", "dir").
Str("path", dirPath).
Logger(),
diff --git a/internal/watcher/docker_watcher.go b/internal/watcher/docker_watcher.go
index 09713c6..1a01f9e 100644
--- a/internal/watcher/docker_watcher.go
+++ b/internal/watcher/docker_watcher.go
@@ -8,9 +8,9 @@ import (
docker_events "github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/watcher/events"
)
@@ -81,7 +81,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
}()
cEventCh, cErrCh := client.Events(ctx, options)
- defer logging.Debug().Str("host", client.Address()).Msg("docker watcher closed")
+ defer log.Debug().Str("host", client.Address()).Msg("docker watcher closed")
for {
select {
case <-ctx.Done():
@@ -153,7 +153,7 @@ func checkConnection(ctx context.Context, client *docker.SharedClient) bool {
defer cancel()
err := client.CheckConnection(ctx)
if err != nil {
- logging.Debug().Err(err).Msg("docker watcher: connection failed")
+ log.Debug().Err(err).Msg("docker watcher: connection failed")
return false
}
return true
diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go
index a6b029f..9e0ead7 100644
--- a/internal/watcher/health/monitor/monitor.go
+++ b/internal/watcher/health/monitor/monitor.go
@@ -7,9 +7,9 @@ import (
"net/url"
"time"
+ "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
- "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
@@ -48,7 +48,7 @@ func NewMonitor(r routes.Route) health.HealthMonCheck {
case routes.StreamRoute:
mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig())
default:
- logging.Panic().Msgf("unexpected route type: %T", r)
+ log.Panic().Msgf("unexpected route type: %T", r)
}
}
if r.IsDocker() {
@@ -91,7 +91,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
mon.task = parent.Subtask("health_monitor", true)
go func() {
- logger := logging.With().Str("name", mon.service).Logger()
+ logger := log.With().Str("name", mon.service).Logger()
defer func() {
if mon.status.Load() != health.StatusError {
@@ -221,7 +221,7 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
}
func (mon *monitor) checkUpdateHealth() error {
- logger := logging.With().Str("name", mon.Name()).Logger()
+ logger := log.With().Str("name", mon.Name()).Logger()
result, err := mon.checkHealth()
var lastStatus health.Status
diff --git a/pkg/version.go b/pkg/version.go
index 03c42ec..c037f0e 100644
--- a/pkg/version.go
+++ b/pkg/version.go
@@ -18,7 +18,7 @@ func GetLastVersion() Version {
func GetVersionHTTPHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte(GetVersion().String()))
+ fmt.Fprint(w, GetVersion().String())
}
}
@@ -34,16 +34,16 @@ func init() {
// lastVersion = ParseVersion(lastVersionStr)
// }
// if err != nil && !os.IsNotExist(err) {
- // logging.Warn().Err(err).Msg("failed to read version file")
+ // log.Warn().Err(err).Msg("failed to read version file")
// return
// }
// if err := f.Truncate(0); err != nil {
- // logging.Warn().Err(err).Msg("failed to truncate version file")
+ // log.Warn().Err(err).Msg("failed to truncate version file")
// return
// }
// _, err = f.WriteString(version)
// if err != nil {
- // logging.Warn().Err(err).Msg("failed to save version file")
+ // log.Warn().Err(err).Msg("failed to save version file")
// return
// }
}
diff --git a/socket-proxy/go.mod b/socket-proxy/go.mod
index cd8a7de..5c83f0b 100644
--- a/socket-proxy/go.mod
+++ b/socket-proxy/go.mod
@@ -2,4 +2,19 @@ module github.com/yusing/go-proxy/socketproxy
go 1.24.3
-require github.com/gorilla/mux v1.8.1
+replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
+
+require (
+ github.com/gorilla/mux v1.8.1
+ github.com/yusing/go-proxy/internal/utils v0.0.0-00010101000000-000000000000
+ golang.org/x/net v0.40.0
+)
+
+require (
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/rs/zerolog v1.34.0 // indirect
+ go.uber.org/atomic v1.11.0 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+ golang.org/x/text v0.25.0 // indirect
+)
diff --git a/socket-proxy/go.sum b/socket-proxy/go.sum
index 7128337..7967116 100644
--- a/socket-proxy/go.sum
+++ b/socket-proxy/go.sum
@@ -1,2 +1,34 @@
+github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
+github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
+github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
+golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
+golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
+golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/socket-proxy/pkg/handler.go b/socket-proxy/pkg/handler.go
index 51c2049..930bc9e 100644
--- a/socket-proxy/pkg/handler.go
+++ b/socket-proxy/pkg/handler.go
@@ -4,30 +4,33 @@ import (
"context"
"net"
"net/http"
- "net/http/httputil"
"strings"
"time"
"github.com/gorilla/mux"
+ "github.com/yusing/go-proxy/socketproxy/pkg/reverseproxy"
)
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
-func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
- return dialer.DialContext(ctx, "unix", DockerSocket)
+func dialDockerSocket(socket string) func(ctx context.Context, _, _ string) (net.Conn, error) {
+ return func(ctx context.Context, _, _ string) (net.Conn, error) {
+ return dialer.DialContext(ctx, "unix", socket)
+ }
}
var DockerSocketHandler = dockerSocketHandler
-func dockerSocketHandler() http.HandlerFunc {
- rp := &httputil.ReverseProxy{
+func dockerSocketHandler(socket string) http.HandlerFunc {
+ rp := &reverseproxy.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = "http"
req.URL.Host = "api.moby.localhost"
req.RequestURI = req.URL.String()
},
Transport: &http.Transport{
- DialContext: dialDockerSocket,
+ DialContext: dialDockerSocket(socket),
+ DisableCompression: true,
},
}
@@ -41,7 +44,7 @@ func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
func NewHandler() http.Handler {
r := mux.NewRouter()
- socketHandler := DockerSocketHandler()
+ socketHandler := DockerSocketHandler(DockerSocket)
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
diff --git a/socket-proxy/pkg/handler_test.go b/socket-proxy/pkg/handler_test.go
index 59748e4..4c4cc0a 100644
--- a/socket-proxy/pkg/handler_test.go
+++ b/socket-proxy/pkg/handler_test.go
@@ -9,7 +9,7 @@ import (
. "github.com/yusing/go-proxy/socketproxy/pkg"
)
-func mockDockerSocketHandler() http.HandlerFunc {
+func mockDockerSocketHandler(_ string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("mock docker response"))
diff --git a/socket-proxy/pkg/reverseproxy/reverse_proxy.go b/socket-proxy/pkg/reverseproxy/reverse_proxy.go
new file mode 100644
index 0000000..b2abaff
--- /dev/null
+++ b/socket-proxy/pkg/reverseproxy/reverse_proxy.go
@@ -0,0 +1,367 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+// License URL: https://cs.opensource.google/go/go/+/master:LICENSE
+
+// HTTP reverse proxy handler
+
+package reverseproxy
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptrace"
+ "net/textproto"
+ "strings"
+ "sync"
+
+ "github.com/yusing/go-proxy/internal/utils"
+ "golang.org/x/net/http/httpguts"
+)
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+//
+// 1xx responses are forwarded to the client if the underlying
+// transport supports ClientTrace.Got1xxResponse.
+type ReverseProxy struct {
+ // Director is a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ // Director must not access the provided Request
+ // after returning.
+ //
+ // By default, the X-Forwarded-For header is set to the
+ // value of the client IP address. If an X-Forwarded-For
+ // header already exists, the client IP is appended to the
+ // existing values. As a special case, if the header
+ // exists in the Request.Header map but has a nil value
+ // (such as when set by the Director func), the X-Forwarded-For
+ // header is not modified.
+ //
+ // To prevent IP spoofing, be sure to delete any pre-existing
+ // X-Forwarded-For header coming from the client or
+ // an untrusted proxy.
+ //
+ // Hop-by-hop headers are removed from the request after
+ // Director returns, which can remove headers added by
+ // Director. Use a Rewrite function instead to ensure
+ // modifications to the request are preserved.
+ //
+ // Unparsable query parameters are removed from the outbound
+ // request if Request.Form is set after Director returns.
+ //
+ // At most one of Rewrite or Director may be set.
+ Director func(*http.Request)
+
+ // The transport used to perform proxy requests.
+ // If nil, http.DefaultTransport is used.
+ Transport http.RoundTripper
+
+ // ErrorLog specifies an optional logger for errors
+ // that occur when attempting to proxy the request.
+ // If nil, logging is done via the log package's standard logger.
+ ErrorLog *log.Logger
+
+ // ErrorHandler is an optional function that handles errors
+ // reaching the backend or errors from ModifyResponse.
+ //
+ // If nil, the default is to log the provided error and return
+ // a 502 Status Bad Gateway response.
+ ErrorHandler func(http.ResponseWriter, *http.Request, error)
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
+var hopHeaders = []string{
+ "Connection",
+ "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+}
+
+func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
+ if p.ErrorHandler != nil {
+ return p.ErrorHandler
+ }
+ return p.defaultErrorHandler
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ transport := p.Transport
+ ctx := req.Context()
+
+ outreq := req.Clone(ctx)
+ if req.ContentLength == 0 {
+ outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ }
+ if outreq.Body != nil {
+ // Reading from the request body after returning from a handler is not
+ // allowed, and the RoundTrip goroutine that reads the Body can outlive
+ // this handler. This can lead to a crash if the handler panics (see
+ // Issue 46866). Although calling Close doesn't guarantee there isn't
+ // any Read in flight after the handle returns, in practice it's safe to
+ // read after closing it.
+ defer outreq.Body.Close()
+ }
+ if outreq.Header == nil {
+ outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
+ }
+
+ p.Director(outreq)
+ outreq.Close = false
+
+ reqUpType := upgradeType(outreq.Header)
+ if !IsPrint(reqUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
+ return
+ }
+
+ req.Header.Del("Forwarded")
+ removeHopByHopHeaders(outreq.Header)
+
+ // Issue 21096: tell backend applications that care about trailer support
+ // that we support trailers. (We do, but we don't go out of our way to
+ // advertise that unless the incoming client request thought it was worth
+ // mentioning.) Note that we look at req.Header, not outreq.Header, since
+ // the latter has passed through removeHopByHopHeaders.
+ if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
+ outreq.Header.Set("Te", "trailers")
+ }
+
+ if _, ok := outreq.Header["User-Agent"]; !ok {
+ // If the outbound request doesn't have a User-Agent header set,
+ // don't send the default Go HTTP client User-Agent.
+ outreq.Header.Set("User-Agent", "")
+ }
+
+ var (
+ roundTripMutex sync.Mutex
+ roundTripDone bool
+ )
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ roundTripMutex.Lock()
+ defer roundTripMutex.Unlock()
+ if roundTripDone {
+ // If RoundTrip has returned, don't try to further modify
+ // the ResponseWriter's header map.
+ return nil
+ }
+ h := rw.Header()
+ copyHeader(h, http.Header(header))
+ rw.WriteHeader(code)
+
+ // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
+ clear(h)
+ return nil
+ },
+ }
+ outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
+
+ res, err := transport.RoundTrip(outreq)
+ roundTripMutex.Lock()
+ roundTripDone = true
+ roundTripMutex.Unlock()
+ if err != nil {
+ p.getErrorHandler()(rw, outreq, err)
+ return
+ }
+
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ p.handleUpgradeResponse(rw, outreq, res)
+ return
+ }
+
+ removeHopByHopHeaders(res.Header)
+
+ copyHeader(rw.Header(), res.Header)
+
+ // The "Trailer" header isn't included in the Transport's response,
+ // at least for *http.Transport. Build it up from Trailer.
+ announcedTrailers := len(res.Trailer)
+ if announcedTrailers > 0 {
+ trailerKeys := make([]string, 0, len(res.Trailer))
+ for k := range res.Trailer {
+ trailerKeys = append(trailerKeys, k)
+ }
+ rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
+ }
+
+ rw.WriteHeader(res.StatusCode)
+
+ err = utils.CopyCloseWithContext(ctx, rw, res.Body)
+ if err != nil {
+ if !errors.Is(err, context.Canceled) {
+ p.getErrorHandler()(rw, req, err)
+ }
+ return
+ }
+
+ if len(res.Trailer) > 0 {
+ // Force chunking if we saw a response trailer.
+ // This prevents net/http from calculating the length for short
+ // bodies and adding a Content-Length.
+ http.NewResponseController(rw).Flush()
+ }
+
+ if len(res.Trailer) == announcedTrailers {
+ copyHeader(rw.Header(), res.Trailer)
+ return
+ }
+
+ for k, vv := range res.Trailer {
+ k = http.TrailerPrefix + k
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+}
+
+// removeHopByHopHeaders removes hop-by-hop headers.
+func removeHopByHopHeaders(h http.Header) {
+ // RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
+ for _, f := range h["Connection"] {
+ for sf := range strings.SplitSeq(f, ",") {
+ if sf = textproto.TrimString(sf); sf != "" {
+ h.Del(sf)
+ }
+ }
+ }
+ // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
+ // This behavior is superseded by the RFC 7230 Connection header, but
+ // preserve it for backwards compatibility.
+ for _, f := range hopHeaders {
+ h.Del(f)
+ }
+}
+
+func (p *ReverseProxy) logf(format string, args ...any) {
+ if p.ErrorLog != nil {
+ p.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+func upgradeType(h http.Header) string {
+ if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+ return ""
+ }
+ return h.Get("Upgrade")
+}
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
+ return
+ }
+ if !strings.EqualFold(reqUpType, resUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+
+ rc := http.NewResponseController(rw)
+ conn, brw, hijackErr := rc.Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+
+ backConnCloseCh := make(chan bool)
+ go func() {
+ // Ensure that the cancellation of a request closes the backend.
+ // See issue https://golang.org/issue/35559.
+ select {
+ case <-req.Context().Done():
+ case <-backConnCloseCh:
+ }
+ backConn.Close()
+ }()
+ defer close(backConnCloseCh)
+
+ if hijackErr != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
+ return
+ }
+ defer conn.Close()
+
+ copyHeader(rw.Header(), res.Header)
+
+ res.Header = rw.Header()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
+
+func IsPrint(s string) bool {
+ for _, r := range s {
+ if r < ' ' || r > '~' {
+ return false
+ }
+ }
+ return true
+}