fix: optimize memory usage, fix agent and code refactor (#118)
Some checks are pending
Docker Image CI (socket-proxy) / build (push) Waiting to run

* refactor: simplify io code and make utils module independent

* fix(docker): agent and socket-proxy docker event flushing with modified reverse proxy handler

* refactor: remove unused code

* refactor: remove the use of logging module in most code

* refactor: streamline domain mismatch check in certState function

* tweak: use ecdsa p-256 for autocert

* fix(tests): update health check tests for invalid host and add case for port in host

* feat(acme): custom acme directory

* refactor: code refactor and improved context and error handling

* tweak: optimize memory usage under load

* fix(oidc): restore old user matching behavior

* docs: add ChatGPT assistant to README

---------

Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
Yuzerion 2025-05-25 09:45:57 +08:00 committed by GitHub
parent ff08c40403
commit 4a8bd48ad5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
98 changed files with 1549 additions and 555 deletions

View file

@ -2,15 +2,16 @@ version: "2"
linters:
default: all
disable:
- bodyclose
# - bodyclose
- containedctx
- contextcheck
# - contextcheck
- cyclop
- depguard
- dupl
# - dupl
- err113
- exhaustive
- exhaustruct
- funcorder
- forcetypeassert
- gochecknoglobals
- gochecknoinits
@ -18,7 +19,6 @@ linters:
- goconst
- gocyclo
- gomoddirectives
- gosec
- gosmopolitan
- ireturn
- lll
@ -27,12 +27,10 @@ linters:
- mnd
- nakedret
- nestif
- nilnil
- nlreturn
- noctx
- nonamedreturns
- paralleltest
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- tagliatelle

View file

@ -21,7 +21,7 @@ lint:
- markdownlint
- yamllint
enabled:
- checkov@3.2.416
- checkov@3.2.432
- golangci-lint2@2.1.6
- hadolint@2.12.1-beta
- actionlint@1.7.7
@ -32,7 +32,7 @@ lint:
- prettier@3.5.3
- shellcheck@0.10.0
- shfmt@3.6.0
- trufflehog@3.88.29
- trufflehog@3.88.33
actions:
disabled:
- trunk-announce

View file

@ -16,6 +16,8 @@ A lightweight, simple, and performant reverse proxy with WebUI.
<h5>EN | <a href="README_CHT.md">中文</a></h5>
Have questions? Ask [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)! (Thanks to [@ismesid](https://github.com/arevindh))
<img src="screenshots/webui.jpg" style="max-width: 650">
</div>

View file

@ -16,6 +16,8 @@
<h5><a href="README.md">EN</a> | 中文</h5>
有疑問? 問 [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)!(鳴謝 [@ismesid](https://github.com/arevindh)
<img src="https://github.com/user-attachments/assets/4bb371f4-6e4c-425c-89b2-b9e962bdd46f" style="max-width: 650">
</div>

View file

@ -1,11 +1,14 @@
package main
import (
"os"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/server"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/task"
@ -14,6 +17,12 @@ import (
)
func main() {
writer := zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "01-02 15:04",
}
zerolog.TimeFieldFormat = writer.TimeFormat
log.Logger = zerolog.New(writer).Level(zerolog.InfoLevel).With().Timestamp().Logger()
ca := &agent.PEMPair{}
err := ca.Load(env.AgentCACert)
if err != nil {
@ -34,11 +43,11 @@ func main() {
gperr.LogFatal("init SSL error", err)
}
logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
logging.Info().Msgf("Agent name: %s", env.AgentName)
logging.Info().Msgf("Agent port: %d", env.AgentPort)
log.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
log.Info().Msgf("Agent name: %s", env.AgentName)
log.Info().Msgf("Agent port: %d", env.AgentPort)
logging.Info().Msg(`
log.Info().Msg(`
Tips:
1. To change the agent name, you can set the AGENT_NAME environment variable.
2. To change the agent port, you can set the AGENT_PORT environment variable.
@ -54,7 +63,7 @@ Tips:
server.StartAgentServer(t, opts)
if socketproxy.ListenAddr != "" {
logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
log.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
opts := httpServer.Options{
Name: "docker",
HTTPAddr: socketproxy.ListenAddr,

View file

@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy => ..
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327
@ -15,6 +17,7 @@ require (
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
github.com/yusing/go-proxy/internal/utils v0.0.0
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
)

View file

@ -15,8 +15,8 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/pkg"
)
@ -115,7 +115,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
cfg.name = string(name)
cfg.l = logging.With().Str("agent", cfg.name).Logger()
cfg.l = log.With().Str("agent", cfg.name).Logger()
// check agent version
agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion)
@ -127,10 +127,10 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
agentVersion := pkg.ParseVersion(cfg.version)
if serverVersion.IsNewerMajorThan(agentVersion) {
logging.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
}
logging.Info().Msgf("agent %q initialized", cfg.name)
log.Info().Msgf("agent %q initialized", cfg.name)
return nil
}

View file

@ -172,9 +172,9 @@ func TestCheckHealthTCPUDP(t *testing.T) {
{
name: "InvalidHost",
scheme: "tcp",
host: "invalid",
host: "",
port: 8080,
expectedStatus: http.StatusOK,
expectedStatus: http.StatusBadRequest,
expectedHealthy: false,
},
{
@ -188,9 +188,17 @@ func TestCheckHealthTCPUDP(t *testing.T) {
{
name: "InvalidHost",
scheme: "udp",
host: "invalid",
host: "",
port: 8080,
expectedStatus: http.StatusOK,
expectedStatus: http.StatusBadRequest,
expectedHealthy: false,
},
{
name: "Port in both host and port",
scheme: "tcp",
host: "localhost:1234",
port: 1234,
expectedStatus: http.StatusBadRequest,
expectedHealthy: false,
},
}
@ -208,9 +216,11 @@ func TestCheckHealthTCPUDP(t *testing.T) {
require.Equal(t, recorder.Code, tt.expectedStatus)
var result health.HealthCheckResult
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
require.Equal(t, result.Healthy, tt.expectedHealthy)
if tt.expectedStatus == http.StatusOK {
var result health.HealthCheckResult
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
require.Equal(t, result.Healthy, tt.expectedHealthy)
}
})
}
}

View file

@ -1,17 +1,14 @@
package handler
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"time"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/pkg"
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
type ServeMux struct{ *http.ServeMux }
@ -24,26 +21,6 @@ func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
}
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", env.DockerSocket)
}
func dockerSocketHandler() http.HandlerFunc {
rp := httputil.ReverseProxy{
Director: func(r *http.Request) {
r.URL.Scheme = "http"
r.URL.Host = "api.moby.localhost"
r.RequestURI = r.URL.String()
},
Transport: &http.Transport{
DialContext: dialDockerSocket,
},
}
return rp.ServeHTTP
}
func NewAgentHandler() http.Handler {
mux := ServeMux{http.NewServeMux()}
@ -54,6 +31,6 @@ func NewAgentHandler() http.Handler {
})
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", dockerSocketHandler())
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler(env.DockerSocket))
return mux
}

View file

@ -6,9 +6,9 @@ import (
"fmt"
"net/http"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/handler"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/task"
)
@ -33,12 +33,11 @@ func StartAgentServer(parent task.Parent, opt Options) {
tlsConfig.ClientAuth = tls.NoClientCert
}
logger := logging.GetLogger()
agentServer := &http.Server{
Addr: fmt.Sprintf(":%d", opt.Port),
Handler: handler.NewAgentHandler(),
TLSConfig: tlsConfig,
}
server.Start(parent, agentServer, nil, logger)
server.Start(parent, agentServer, nil, &log.Logger)
}

View file

@ -4,6 +4,7 @@ import (
"os"
"sync"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/auth"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config"
@ -35,8 +36,8 @@ func main() {
initProfiling()
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
logging.Trace().Msg("trace enabled")
log.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
log.Trace().Msg("trace enabled")
parallel(
dnsproviders.InitProviders,
homepage.InitIconListCache,
@ -45,7 +46,7 @@ func main() {
)
if common.APIJWTSecret == nil {
logging.Warn().Msg("API_JWT_SECRET is not set, using random key")
log.Warn().Msg("API_JWT_SECRET is not set, using random key")
common.APIJWTSecret = common.RandomJWTKey()
}
@ -62,7 +63,7 @@ func main() {
Proxy: true,
})
if err := auth.Initialize(); err != nil {
logging.Fatal().Err(err).Msg("failed to initialize authentication")
log.Fatal().Err(err).Msg("failed to initialize authentication")
}
// API Handler needs to start after auth is initialized.
cfg.StartServers(&config.StartServersOptions{
@ -78,7 +79,7 @@ func main() {
func prepareDirectory(dir string) {
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err = os.MkdirAll(dir, 0o755); err != nil {
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
log.Fatal().Msgf("failed to create directory %s: %v", dir, err)
}
}
}

7
go.mod
View file

@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy/agent => ./agent
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
replace github.com/yusing/go-proxy/internal/utils => ./internal/utils
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
@ -45,7 +47,7 @@ require (
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000
go.uber.org/atomic v1.11.0
github.com/yusing/go-proxy/internal/utils v0.0.0
)
require (
@ -219,6 +221,7 @@ require (
go.opentelemetry.io/otel v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.5.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
@ -226,7 +229,7 @@ require (
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0
golang.org/x/text v0.25.0 // indirect
golang.org/x/tools v0.33.0 // indirect
google.golang.org/api v0.234.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect

View file

@ -5,9 +5,9 @@ import (
"time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/task"
@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool {
return c.created.Add(cacheTTL).Before(utils.TimeNow())
}
//TODO: add stats
// TODO: add stats
const (
ACLAllow = "allow"
@ -97,7 +97,7 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
if c.valErr != nil {
return c.valErr
}
logging.Info().
log.Info().
Str("default", c.Default).
Bool("allow_local", c.allowLocal).
Int("allow_rules", len(c.Allow)).

View file

@ -6,7 +6,7 @@ import (
"testing"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
func TestMatchers(t *testing.T) {
@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) {
}
var mathers Matchers
err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
if err != nil {
t.Fatal(err)
}

View file

@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil }
func (noConn) SetReadDeadline(t time.Time) error { return nil }
func (noConn) SetWriteDeadline(t time.Time) error { return nil }
func (cfg *Config) WrapTCP(lis net.Listener) net.Listener {
if cfg == nil {
func (c *Config) WrapTCP(lis net.Listener) net.Listener {
if c == nil {
return lis
}
return &TCPListener{
acl: cfg,
acl: c,
lis: lis,
}
}

View file

@ -3,9 +3,9 @@ package certapi
import (
"net/http"
"github.com/rs/zerolog/log"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
)
@ -36,7 +36,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
gperr.LogError("failed to obtain cert", err)
_ = gpwebsocket.WriteText(conn, err.Error())
} else {
logging.Info().Msg("cert obtained successfully")
log.Info().Msg("cert obtained successfully")
}
}()
for {

View file

@ -9,12 +9,13 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/pkg/stdcopy"
"github.com/gorilla/websocket"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/task"
)
// FIXME: agent logs not updating.
func Logs(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
server := r.PathValue("server")
@ -68,7 +69,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
return
}
logging.Err(err).
log.Err(err).
Str("server", server).
Str("container", containerID).
Msg("failed to de-multiplex logs")

View file

@ -11,9 +11,9 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/jsonstore"
"github.com/yusing/go-proxy/internal/logging"
"golang.org/x/oauth2"
)
@ -108,11 +108,11 @@ func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
RefreshToken: token,
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
})
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
log.Debug().Str("username", username).Msg("stored oauth refresh token")
}
func invalidateOAuthRefreshToken(sessionID sessionID) {
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
log.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
oauthRefreshTokens.Delete(string(sessionID))
}
@ -127,7 +127,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
signed, err := jwtToken.SignedString(common.APIJWTSecret)
if err != nil {
logging.Err(err).Msg("failed to sign session token")
log.Err(err).Msg("failed to sign session token")
return
}
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
return nil, refreshToken.err
}
idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken)
if err != nil {
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
return nil, refreshToken.err
@ -205,7 +205,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
sessionID := newSessionID()
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
log.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
refreshToken.result = &RefreshResult{

View file

@ -12,9 +12,9 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/utils"
"golang.org/x/oauth2"
@ -38,8 +38,8 @@ type (
const (
CookieOauthState = "godoxy_oidc_state"
CookieOauthToken = "godoxy_oauth_token"
CookieOauthSessionToken = "godoxy_session_token"
CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
)
const (
@ -79,7 +79,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
if err != nil && provider.EndSessionEndpoint() != "" {
// non critical, just warn
logging.Warn().
log.Warn().
Str("issuer", issuerURL).
Err(err).
Msg("failed to parse end session URL")
@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
}
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
if !ok {
return "", nil, errMissingIDToken
@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
return
}
// clear cookies then redirect to home
logging.Err(err).Msg("failed to refresh token")
log.Err(err).Msg("failed to refresh token")
auth.clearCookie(w, r)
http.Redirect(w, r, "/", http.StatusFound)
return
@ -201,11 +201,12 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
userAllowed := slices.Contains(auth.allowedUsers, user)
if !userAllowed {
return false
if userAllowed {
return true
}
if len(auth.allowedGroups) == 0 {
return true
// user is not allowed, but no groups are allowed
return false
}
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
}
@ -257,7 +258,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
return
}
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token)
idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
if err != nil {
gphttp.ServerError(w, r, err)
return

View file

@ -5,13 +5,16 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"net/http"
"os"
"regexp"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
)
@ -22,13 +25,19 @@ type Config struct {
KeyPath string `json:"key_path,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Provider string `json:"provider,omitempty"`
CADirURL string `json:"ca_dir_url,omitempty"`
Options map[string]any `json:"options,omitempty"`
HTTPClient *http.Client `json:"-"` // for tests only
challengeProvider challenge.Provider
}
var (
ErrMissingDomain = gperr.New("missing field 'domains'")
ErrMissingEmail = gperr.New("missing field 'email'")
ErrMissingProvider = gperr.New("missing field 'provider'")
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
ErrInvalidDomain = gperr.New("invalid domain")
ErrUnknownProvider = gperr.New("unknown provider")
)
@ -36,6 +45,7 @@ var (
const (
ProviderLocal = "local"
ProviderPseudo = "pseudo"
ProviderCustom = "custom"
)
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error {
}
b := gperr.NewBuilder("autocert errors")
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
b.Add(ErrMissingCADirURL)
}
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error {
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
if cfg.Provider != ProviderCustom {
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
if cfg.Provider != ProviderCustom {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
}
} else {
_, err := providerConstructor(cfg.Options)
provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
} else {
cfg.challengeProvider = provider
}
}
}
if cfg.challengeProvider == nil {
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
}
return b.Error()
}
@ -100,8 +124,7 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if privKey, err = cfg.LoadACMEKey(); err != nil {
logging.Info().Err(err).Msg("load ACME private key failed")
logging.Info().Msg("generate new ACME private key")
log.Info().Err(err).Msg("failed to load ACME private key, generating a now one")
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, gperr.New("generate ACME private key").With(err)
@ -118,12 +141,23 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
}
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
legoCfg.Certificate.KeyType = certcrypto.EC256
if cfg.HTTPClient != nil {
legoCfg.HTTPClient = cfg.HTTPClient
}
if cfg.CADirURL != "" {
legoCfg.CADirURL = cfg.CADirURL
}
return user, legoCfg, nil
}
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
if common.IsTest {
return nil, os.ErrNotExist
}
data, err := os.ReadFile(cfg.ACMEKeyPath)
if err != nil {
return nil, err
@ -132,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
}
func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error {
if common.IsTest {
return nil
}
data, err := x509.MarshalECPrivateKey(key)
if err != nil {
return err

View file

@ -5,18 +5,20 @@ import (
"crypto/x509"
"errors"
"fmt"
"maps"
"os"
"path"
"reflect"
"sort"
"slices"
"strings"
"time"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
@ -76,13 +78,11 @@ func (p *Provider) ObtainCert() error {
}
if p.cfg.Provider == ProviderPseudo {
t := time.NewTicker(1000 * time.Millisecond)
defer t.Stop()
logging.Info().Msg("init client for pseudo provider")
<-t.C
logging.Info().Msg("registering acme for pseudo provider")
<-t.C
logging.Info().Msg("obtained cert for pseudo provider")
log.Info().Msg("init client for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("registering acme for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("obtained cert for pseudo provider")
return nil
}
@ -107,7 +107,7 @@ func (p *Provider) ObtainCert() error {
})
if err != nil {
p.legoCert = nil
logging.Err(err).Msg("cert renew failed, fallback to obtain")
log.Err(err).Msg("cert renew failed, fallback to obtain")
} else {
p.legoCert = cert
}
@ -154,7 +154,7 @@ func (p *Provider) LoadCert() error {
p.tlsCert = &cert
p.certExpiries = expiries
logging.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
log.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
}
@ -219,13 +219,7 @@ func (p *Provider) initClient() error {
return err
}
generator := Providers[p.cfg.Provider]
legoProvider, pErr := generator(p.cfg.Options)
if pErr != nil {
return pErr
}
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider)
if err != nil {
return err
}
@ -240,7 +234,7 @@ func (p *Provider) registerACME() error {
}
if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil {
p.user.Registration = reg
logging.Info().Msg("reused acme registration from private key")
log.Info().Msg("reused acme registration from private key")
return nil
}
@ -249,11 +243,14 @@ func (p *Provider) registerACME() error {
return err
}
p.user.Registration = reg
logging.Info().Interface("reg", reg).Msg("acme registered")
log.Info().Interface("reg", reg).Msg("acme registered")
return nil
}
func (p *Provider) saveCert(cert *certificate.Resource) error {
if common.IsTest {
return nil
}
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
@ -283,22 +280,19 @@ func (p *Provider) certState() CertState {
return CertStateExpired
}
certDomains := make([]string, len(p.certExpiries))
wantedDomains := make([]string, len(p.cfg.Domains))
i := 0
for domain := range p.certExpiries {
certDomains[i] = domain
i++
}
copy(wantedDomains, p.cfg.Domains)
sort.Strings(wantedDomains)
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logging.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
if len(p.certExpiries) != len(p.cfg.Domains) {
return CertStateMismatch
}
for i := range len(p.cfg.Domains) {
if _, ok := p.certExpiries[p.cfg.Domains[i]]; !ok {
log.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s",
strings.Join(slices.Collect(maps.Keys(p.certExpiries)), ", "),
strings.Join(p.cfg.Domains, ", "))
return CertStateMismatch
}
}
return CertStateValid
}
@ -309,9 +303,9 @@ func (p *Provider) renewIfNeeded() error {
switch p.certState() {
case CertStateExpired:
logging.Info().Msg("certs expired, renewing")
log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
logging.Info().Msg("cert domains mismatch with config, renewing")
log.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
}

View file

@ -0,0 +1,453 @@
//nolint:errchkjson,errcheck
package provider_test
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/dnsproviders"
)
func TestMain(m *testing.M) {
dnsproviders.InitProviders()
m.Run()
}
func TestCustomProvider(t *testing.T) {
t.Run("valid custom provider with step-ca", func(t *testing.T) {
cfg := &autocert.Config{
Email: "test@example.com",
Domains: []string{"example.com", "*.example.com"},
Provider: autocert.ProviderCustom,
CADirURL: "https://ca.example.com:9000/acme/acme/directory",
CertPath: "certs/custom.crt",
KeyPath: "certs/custom.key",
ACMEKeyPath: "certs/custom-acme.key",
}
err := cfg.Validate()
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
require.NoError(t, err)
require.NotNil(t, user)
require.NotNil(t, legoCfg)
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL)
require.Equal(t, "test@example.com", user.Email)
})
t.Run("custom provider missing CADirURL", func(t *testing.T) {
cfg := &autocert.Config{
Email: "test@example.com",
Domains: []string{"example.com"},
Provider: autocert.ProviderCustom,
// CADirURL is missing
}
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
})
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
cfg := &autocert.Config{
Email: "admin@internal.com",
Domains: []string{"internal.example.com", "api.internal.example.com"},
Provider: autocert.ProviderCustom,
CADirURL: "https://step-ca.internal:443/acme/acme/directory",
CertPath: "certs/internal.crt",
KeyPath: "certs/internal.key",
ACMEKeyPath: "certs/internal-acme.key",
}
err := cfg.Validate()
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
require.NoError(t, err)
require.NotNil(t, user)
require.NotNil(t, legoCfg)
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
require.Equal(t, "admin@internal.com", user.Email)
provider := autocert.NewProvider(cfg, user, legoCfg)
require.NotNil(t, provider)
require.Equal(t, autocert.ProviderCustom, provider.GetName())
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
})
}
func TestObtainCertFromCustomProvider(t *testing.T) {
// Create a test ACME server
acmeServer := newTestACMEServer(t)
defer acmeServer.Close()
t.Run("obtain cert from custom step-ca server", func(t *testing.T) {
cfg := &autocert.Config{
Email: "test@example.com",
Domains: []string{"test.example.com"},
Provider: autocert.ProviderCustom,
CADirURL: acmeServer.URL() + "/acme/acme/directory",
CertPath: "certs/stepca-test.crt",
KeyPath: "certs/stepca-test.key",
ACMEKeyPath: "certs/stepca-test-acme.key",
HTTPClient: acmeServer.httpClient(),
}
err := error(cfg.Validate())
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
require.NoError(t, err)
require.NotNil(t, user)
require.NotNil(t, legoCfg)
provider := autocert.NewProvider(cfg, user, legoCfg)
require.NotNil(t, provider)
// Test obtaining certificate
err = provider.ObtainCert()
require.NoError(t, err)
// Verify certificate was obtained
cert, err := provider.GetCert(nil)
require.NoError(t, err)
require.NotNil(t, cert)
// Verify certificate properties
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, x509Cert.DNSNames, "test.example.com")
require.True(t, time.Now().Before(x509Cert.NotAfter))
require.True(t, time.Now().After(x509Cert.NotBefore))
})
}
// testACMEServer implements a minimal ACME server for testing.
type testACMEServer struct {
server *httptest.Server
caCert *x509.Certificate
caKey *rsa.PrivateKey
clientCSRs map[string]*x509.CertificateRequest
orderID string
}
func newTestACMEServer(t *testing.T) *testACMEServer {
t.Helper()
// Generate CA certificate and key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
acme := &testACMEServer{
caCert: caCert,
caKey: caKey,
clientCSRs: make(map[string]*x509.CertificateRequest),
orderID: "test-order-123",
}
mux := http.NewServeMux()
acme.setupRoutes(mux)
acme.server = httptest.NewTLSServer(mux)
return acme
}
func (s *testACMEServer) Close() {
s.server.Close()
}
func (s *testACMEServer) URL() string {
return s.server.URL
}
func (s *testACMEServer) httpClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true, //nolint:gosec
},
},
}
}
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
// ACME directory endpoint
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
// ACME endpoints
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
mux.HandleFunc("/acme/chall/", s.handleChallenge)
mux.HandleFunc("/acme/order/", s.handleOrder)
mux.HandleFunc("/acme/cert/", s.handleCertificate)
}
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
directory := map[string]interface{}{
"newNonce": s.server.URL + "/acme/new-nonce",
"newAccount": s.server.URL + "/acme/new-account",
"newOrder": s.server.URL + "/acme/new-order",
"keyChange": s.server.URL + "/acme/key-change",
"meta": map[string]interface{}{
"termsOfService": s.server.URL + "/terms",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(directory)
}
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "test-nonce-12345")
w.WriteHeader(http.StatusOK)
}
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
account := map[string]interface{}{
"status": "valid",
"contact": []string{"mailto:test@example.com"},
"orders": s.server.URL + "/acme/orders",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/account/1")
w.Header().Set("Replay-Nonce", "test-nonce-67890")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(account)
}
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
authzID := "test-authz-456"
order := map[string]interface{}{
"status": "ready", // Skip pending state for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
w.Header().Set("Replay-Nonce", "test-nonce-order")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authz := map[string]interface{}{
"status": "valid", // Skip challenge validation for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
"challenges": []map[string]interface{}{
{
"type": "dns-01",
"status": "valid",
"url": s.server.URL + "/acme/chall/test-chall-789",
"token": "test-token-abc123",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-authz")
json.NewEncoder(w).Encode(authz)
}
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
challenge := map[string]interface{}{
"type": "dns-01",
"status": "valid",
"url": r.URL.String(),
"token": "test-token-abc123",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-chall")
json.NewEncoder(w).Encode(challenge)
}
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/finalize") {
s.handleFinalize(w, r)
return
}
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
// Read the JWS payload
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
// Extract CSR from JWS payload
csr, err := s.extractCSRFromJWS(body)
if err != nil {
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
return
}
// Store the CSR for certificate generation
s.clientCSRs[s.orderID] = csr
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
// Parse the JWS structure
var jws struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
if err := json.Unmarshal(jwsData, &jws); err != nil {
return nil, err
}
// Decode the payload
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, err
}
// Parse the finalize request
var finalizeReq struct {
CSR string `json:"csr"`
}
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
return nil, err
}
// Decode the CSR
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
if err != nil {
return nil, err
}
// Parse the CSR
csr, err := x509.ParseCertificateRequest(csrBytes)
if err != nil {
return nil, err
}
return csr, nil
}
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
// Extract order ID from URL
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
// Get the CSR for this order
csr, exists := s.clientCSRs[orderID]
if !exists {
http.Error(w, "No CSR found for order", http.StatusBadRequest)
return
}
// Create certificate using the public key from the client's CSR
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Cert"},
Country: []string{"US"},
},
DNSNames: csr.DNSNames,
NotBefore: time.Now(),
NotAfter: time.Now().Add(90 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
// Use the public key from the CSR and sign with CA key
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Return certificate chain
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
w.Header().Set("Content-Type", "application/pem-certificate-chain")
w.Header().Set("Replay-Nonce", "test-nonce-cert")
w.Write(append(certPEM, caPEM...))
}

View file

@ -6,7 +6,7 @@ import (
"github.com/go-acme/lego/v4/providers/dns/ovh"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
// type Config struct {
@ -45,6 +45,6 @@ oauth2_config:
testYaml = testYaml[1:] // remove first \n
opt := make(map[string]any)
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg))
require.Equal(t, cfgExpected, cfg)
}

View file

@ -3,7 +3,7 @@ package autocert
import (
"github.com/go-acme/lego/v4/challenge"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type Generator func(map[string]any) (challenge.Provider, gperr.Error)
@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider](
) Generator {
return func(opt map[string]any) (challenge.Provider, gperr.Error) {
cfg := defaultCfg()
err := utils.MapUnmarshalValidate(opt, &cfg)
if err != nil {
return nil, err
if len(opt) > 0 {
err := serialization.MapUnmarshalValidate(opt, &cfg)
if err != nil {
return nil, err
}
}
p, pErr := newProvider(cfg)
return p, gperr.Wrap(pErr)

View file

@ -4,7 +4,7 @@ import (
"errors"
"os"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -13,14 +13,14 @@ func (p *Provider) Setup() (err error) {
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
logging.Debug().Msg("obtaining cert due to error loading cert")
log.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
}
for _, expiry := range p.GetExpiries() {
logging.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
break
}

View file

@ -10,20 +10,20 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/api"
autocert "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/entrypoint"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/maxmind"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
"github.com/yusing/go-proxy/internal/watcher"
@ -96,10 +96,10 @@ func OnConfigChange(ev []events.Event) {
// just reload once and check the last event
switch ev[len(ev)-1].Action {
case events.ActionFileRenamed:
logging.Warn().Msg(cfgRenameWarn)
log.Warn().Msg(cfgRenameWarn)
return
case events.ActionFileDeleted:
logging.Warn().Msg(cfgDeleteWarn)
log.Warn().Msg(cfgDeleteWarn)
return
}
@ -161,7 +161,7 @@ func (cfg *Config) Start(opts ...*StartServersOptions) {
func (cfg *Config) StartAutoCert() {
autocert := cfg.autocertProvider
if autocert == nil {
logging.Info().Msg("autocert not configured")
log.Info().Msg("autocert not configured")
return
}
@ -223,7 +223,7 @@ func (cfg *Config) load() gperr.Error {
}
model := config.DefaultConfig()
if err := utils.UnmarshalValidateYAML(data, model); err != nil {
if err := serialization.UnmarshalValidateYAML(data, model); err != nil {
gperr.LogFatal(errMsg, err)
}
@ -374,6 +374,6 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
})
logging.Info().Msg(results.String())
log.Info().Msg(results.String())
return errs.Error()
}

View file

@ -14,7 +14,7 @@ import (
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
@ -93,14 +93,14 @@ func HasInstance() bool {
func Validate(data []byte) gperr.Error {
var model Config
return utils.UnmarshalValidateYAML(data, &model)
return serialization.UnmarshalValidateYAML(data, &model)
}
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
func init() {
utils.RegisterDefaultValueFactory(DefaultConfig)
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
serialization.RegisterDefaultValueFactory(DefaultConfig)
serialization.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
domains := fl.Field().Interface().([]string)
for _, domain := range domains {
if !matchDomainsRegex.MatchString(domain) {
@ -109,7 +109,7 @@ func init() {
}
return true
})
utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
serialization.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
m := fl.Field().Interface().(map[string]string)
for k := range m {
if k == "" {

View file

@ -4,6 +4,8 @@ go 1.24.3
replace github.com/yusing/go-proxy => ../..
replace github.com/yusing/go-proxy/internal/utils => ../utils
require (
github.com/go-acme/lego/v4 v4.23.1
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
@ -156,6 +158,7 @@ require (
github.com/vultr/govultr/v3 v3.20.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusing/go-proxy/internal/utils v0.0.0 // indirect
go.mongodb.org/mongo-driver v1.17.3 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect

View file

@ -13,13 +13,14 @@ import (
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
// TODO: implement reconnect here.
type (
SharedClient struct {
*client.Client
@ -83,7 +84,7 @@ func closeTimedOutClients() {
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
delete(clientMap, c.Key())
c.Client.Close()
logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
log.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
}
}
}
@ -148,7 +149,7 @@ func NewClient(host string) (*SharedClient, error) {
default:
helper, err := connhelper.GetConnectionHelper(host)
if err != nil {
logging.Panic().Err(err).Msg("failed to get connection helper")
log.Panic().Err(err).Msg("failed to get connection helper")
}
if helper != nil {
httpClient := &http.Client{
@ -189,10 +190,10 @@ func NewClient(host string) (*SharedClient, error) {
c.dial = client.Dialer()
}
if c.addr == "" {
c.addr = c.Client.DaemonHost()
c.addr = c.DaemonHost()
}
defer logging.Debug().Str("host", host).Msg("docker client initialized")
defer log.Debug().Str("host", host).Msg("docker client initialized")
clientMap[c.Key()] = c
return c, nil

View file

@ -9,11 +9,12 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/go-connections/nat"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/utils"
)
@ -90,7 +91,7 @@ func FromDocker(c *container.SummaryTrimmed, dockerHost string) (res *Container)
var ok bool
res.Agent, ok = config.GetInstance().GetAgent(dockerHost)
if !ok {
logging.Error().Msgf("agent %q not found", dockerHost)
log.Error().Msgf("agent %q not found", dockerHost)
}
}
@ -183,7 +184,7 @@ func (c *Container) setPublicHostname() {
}
url, err := url.Parse(c.DockerHost)
if err != nil {
logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
log.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
c.PublicHostname = "127.0.0.1"
return
}
@ -224,7 +225,7 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
ContainerName: c.ContainerName,
},
}
err := utils.MapUnmarshalValidate(cfg, idwCfg)
err := serialization.MapUnmarshalValidate(cfg, idwCfg)
if err != nil {
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
} else {

View file

@ -6,7 +6,7 @@ import (
"net/http"
"strings"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/logging/accesslog"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
@ -50,7 +50,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
}
ep.middleware = mid
logging.Debug().Msg("entrypoint middleware loaded")
log.Debug().Msg("entrypoint middleware loaded")
return nil
}
@ -64,7 +64,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request
if err != nil {
return
}
logging.Debug().Msg("entrypoint access logger created")
log.Debug().Msg("entrypoint access logger created")
return
}
@ -89,7 +89,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
logging.Err(err).
log.Err(err).
Str("method", r.Method).
Str("url", r.URL.String()).
Str("remote", r.RemoteAddr).
@ -99,7 +99,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if _, err := w.Write(errorPage); err != nil {
logging.Err(err).Msg("failed to write error page")
log.Err(err).Msg("failed to write error page")
}
} else {
http.Error(w, err.Error(), http.StatusNotFound)

View file

@ -4,8 +4,8 @@ import (
"os"
"github.com/rs/zerolog"
zerologlog "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
)
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
@ -13,7 +13,7 @@ func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger)
if len(logger) > 0 {
l = logger[0]
} else {
l = logging.GetLogger()
l = &zerologlog.Logger
}
l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error())
switch level {

View file

@ -5,7 +5,7 @@ import (
"slices"
"github.com/yusing/go-proxy/internal/homepage/widgets"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
@ -32,7 +32,7 @@ type (
)
func init() {
utils.RegisterDefaultValueFactory(func() *ItemConfig {
serialization.RegisterDefaultValueFactory(func() *ItemConfig {
return &ItemConfig{
Show: true,
}

View file

@ -6,9 +6,9 @@ import (
"sync"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/jsonstore"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/atomic"
@ -74,7 +74,7 @@ func pruneExpiredIconCache() {
}
}
if nPruned > 0 {
logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
log.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
}
}
@ -87,7 +87,7 @@ func loadIconCache(key string) *FetchResult {
defer iconMu.RUnlock()
icon, ok := iconCache.Load(key)
if ok && len(icon.Icon) > 0 {
logging.Debug().
log.Debug().
Str("key", key).
Msg("icon found in cache")
icon.LastAccess.Store(utils.TimeNow())
@ -99,7 +99,7 @@ func loadIconCache(key string) *FetchResult {
func storeIconCache(key string, result *FetchResult) {
icon := result.Icon
if len(icon) > maxIconSize {
logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
log.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
return
}
@ -109,7 +109,7 @@ func storeIconCache(key string, result *FetchResult) {
entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
entry.LastAccess.Store(time.Now())
iconCache.Store(key, entry)
logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
log.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
}
func (e *cacheEntry) IsExpired() bool {

View file

@ -1,6 +1,7 @@
package homepage
import (
"context"
"encoding/json"
"fmt"
"io"
@ -10,10 +11,10 @@ import (
"time"
"github.com/lithammer/fuzzysearch/fuzzy"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -46,30 +47,30 @@ type (
func (icon *IconMeta) Filenames(ref string) []string {
filenames := make([]string, 0)
if icon.SVG {
filenames = append(filenames, fmt.Sprintf("%s.svg", ref))
filenames = append(filenames, ref+".svg")
if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref))
filenames = append(filenames, ref+"-light.svg")
}
if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref))
filenames = append(filenames, ref+"-dark.svg")
}
}
if icon.PNG {
filenames = append(filenames, fmt.Sprintf("%s.png", ref))
filenames = append(filenames, ref+".png")
if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.png", ref))
filenames = append(filenames, ref+"-light.png")
}
if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref))
filenames = append(filenames, ref+"-dark.png")
}
}
if icon.WebP {
filenames = append(filenames, fmt.Sprintf("%s.webp", ref))
filenames = append(filenames, ref+".webp")
if icon.Light {
filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref))
filenames = append(filenames, ref+"-light.webp")
}
if icon.Dark {
filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref))
filenames = append(filenames, ref+"-dark.webp")
}
}
return filenames
@ -99,21 +100,21 @@ func InitIconListCache() {
iconsCache.Lock()
defer iconsCache.Unlock()
err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache)
err := serialization.LoadJSONIfExist(common.IconListCachePath, iconsCache)
if err != nil {
logging.Error().Err(err).Msg("failed to load icons")
log.Error().Err(err).Msg("failed to load icons")
} else if len(iconsCache.Icons) > 0 {
logging.Info().
log.Info().
Int("icons", len(iconsCache.Icons)).
Msg("icons loaded")
}
if err = updateIcons(); err != nil {
logging.Error().Err(err).Msg("failed to update icons")
log.Error().Err(err).Msg("failed to update icons")
}
task.OnProgramExit("save_icons_cache", func() {
utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
_ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
})
}
@ -134,17 +135,17 @@ func ListAvailableIcons() (*Cache, error) {
iconsCache.Lock()
defer iconsCache.Unlock()
logging.Info().Msg("updating icon data")
log.Info().Msg("updating icon data")
if err := updateIcons(); err != nil {
return nil, err
}
logging.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
log.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
iconsCache.LastUpdate = time.Now()
err := utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
err := serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
if err != nil {
logging.Warn().Err(err).Msg("failed to save icons")
log.Warn().Err(err).Msg("failed to save icons")
}
return iconsCache, nil
}
@ -230,14 +231,17 @@ func updateIcons() error {
var httpGet = httpGetImpl
func MockHttpGet(body []byte) {
func MockHTTPGet(body []byte) {
httpGet = func(_ string) ([]byte, error) {
return body, nil
}
}
func httpGetImpl(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error {
}
data := make([]SelfhStIcon, 0)
err = json.Unmarshal(body, &data)
err = json.Unmarshal(body, &data) //nolint:musttag
if err != nil {
return err
}

View file

@ -68,6 +68,8 @@ type testCases struct {
}
func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
t.Helper()
for _, item := range test {
icon, ok := iconsCache.Icons[item.Key]
if !ok {
@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
}
func TestListWalkxCodeIcons(t *testing.T) {
MockHttpGet([]byte(walkxcodeIcons))
MockHTTPGet([]byte(walkxcodeIcons))
if err := UpdateWalkxCodeIcons(); err != nil {
t.Fatal(err)
}
@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) {
}
func TestListSelfhstIcons(t *testing.T) {
MockHttpGet([]byte(selfhstIcons))
MockHTTPGet([]byte(selfhstIcons))
if err := UpdateSelfhstIcons(); err != nil {
t.Fatal(err)
}

View file

@ -4,7 +4,7 @@ import (
"context"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{
var ErrInvalidProvider = gperr.New("invalid provider")
func (cfg *Config) UnmarshalMap(m map[string]any) error {
cfg.Provider = m["provider"].(string)
var ok bool
cfg.Provider, ok = m["provider"].(string)
if !ok {
return ErrInvalidProvider.Withf("non string")
}
if _, ok := widgetProviders[cfg.Provider]; !ok {
return ErrInvalidProvider.Subject(cfg.Provider)
}
delete(m, "provider")
m, ok := m["config"].(map[string]any)
m, ok = m["config"].(map[string]any)
if !ok {
return gperr.New("invalid config")
}
if err := utils.MapUnmarshalValidate(m, &cfg.Config); err != nil {
return err
}
return nil
return serialization.MapUnmarshalValidate(m, &cfg.Config)
}

View file

@ -7,10 +7,10 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{
}
var (
causeReload = gperr.New("reloaded")
causeContainerDestroy = gperr.New("container destroyed")
causeReload = gperr.New("reloaded") //nolint:errname
causeContainerDestroy = gperr.New("container destroyed") //nolint:errname
)
const reqTimeout = 3 * time.Second
// TODO: fix stream type
// TODO: fix stream type.
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
cfg := r.IdlewatcherConfig()
key := cfg.Key()
@ -120,7 +120,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
return nil, err
}
w.provider = p
w.l = logging.With().
w.l = log.With().
Str("provider", providerType).
Str("container", cfg.ContainerName()).
Logger()

View file

@ -2,18 +2,17 @@ package jsonstore
import (
"encoding/json"
"maps"
"os"
"path/filepath"
"reflect"
"maps"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
type namespace string
@ -36,13 +35,15 @@ type store interface {
json.Unmarshaler
}
var stores = make(map[namespace]store)
var storesPath = common.DataDir
var (
stores = make(map[namespace]store)
storesPath = common.DataDir
)
func init() {
task.OnProgramExit("save_stores", func() {
if err := save(); err != nil {
logging.Error().Err(err).Msg("failed to save stores")
log.Error().Err(err).Msg("failed to save stores")
}
})
}
@ -54,20 +55,20 @@ func loadNS[T store](ns namespace) T {
file, err := os.Open(path)
if err != nil {
if !os.IsNotExist(err) {
logging.Err(err).
log.Err(err).
Str("path", path).
Msg("failed to load store")
}
} else {
defer file.Close()
if err := json.NewDecoder(file).Decode(&store); err != nil {
logging.Err(err).
log.Err(err).
Str("path", path).
Msg("failed to load store")
}
}
stores[ns] = store
logging.Debug().
log.Debug().
Str("namespace", string(ns)).
Str("path", path).
Msg("loaded store")
@ -77,7 +78,7 @@ func loadNS[T store](ns namespace) T {
func save() error {
errs := gperr.NewBuilder("failed to save data stores")
for ns, store := range stores {
if err := utils.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
if err := serialization.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
errs.Add(err)
}
}
@ -86,7 +87,7 @@ func save() error {
func Store[VT any](namespace namespace) MapStore[VT] {
if _, ok := stores[namespace]; ok {
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
}
store := loadNS[*MapStore[VT]](namespace)
stores[namespace] = store
@ -95,7 +96,7 @@ func Store[VT any](namespace namespace) MapStore[VT] {
func Object[Ptr Initializer](namespace namespace) Ptr {
if _, ok := stores[namespace]; ok {
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
}
obj := loadNS[*ObjectStore[Ptr]](namespace)
stores[namespace] = obj
@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error {
}
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
for k, v := range tmp {
s.Map.Store(k, v)
s.Store(k, v)
}
return nil
}

View file

@ -8,8 +8,8 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
if err != nil {
return nil, err
}
if io == nil {
return nil, nil //nolint:nilnil
}
return NewAccessLoggerWithIO(parent, io, cfg), nil
}
@ -120,7 +123,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
bufSize: MinBufferSize,
lineBufPool: synk.NewBytesPool(),
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
logger: logging.With().Str("file", writer.Name()).Logger(),
logger: log.With().Str("file", writer.Name()).Logger(),
}
l.supportRotate = unwrap[supportRotate](writer)
@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
line := l.lineBufPool.Get()
defer l.lineBufPool.Put(line)
line = l.ACLFormatter.AppendACLLog(line, info, blocked)
line = l.AppendACLLog(line, info, blocked)
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool {
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
if !l.ShouldRotate() {
return nil, nil
return nil, nil //nolint:nilnil
}
l.writer.Flush()

View file

@ -4,7 +4,7 @@ import (
"time"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
@ -126,6 +126,6 @@ func DefaultACLLoggerConfig() *ACLLoggerConfig {
}
func init() {
utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
serialization.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
serialization.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
}

View file

@ -5,7 +5,7 @@ import (
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
expect.NoError(t, err)
var config RequestLoggerConfig
err = utils.MapUnmarshalValidate(parsed, &config)
err = serialization.MapUnmarshalValidate(parsed, &config)
expect.NoError(t, err)
expect.Equal(t, config.Format, FormatCombined)

View file

@ -7,7 +7,7 @@ import (
"path/filepath"
"sync"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils"
)
@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) {
if opened, ok := openedFiles[path]; ok {
opened.refCount.Add()
return opened, nil
} else {
// cannot open as O_APPEND as we need Seek and WriteAt
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
if _, err := f.Seek(0, io.SeekEnd); err != nil {
return nil, fmt.Errorf("access log seek error: %w", err)
}
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
}
// cannot open as O_APPEND as we need Seek and WriteAt
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("access log open error: %w", err)
}
if _, err := f.Seek(0, io.SeekEnd); err != nil {
return nil, fmt.Errorf("access log seek error: %w", err)
}
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
openedFiles[path] = file
go file.closeOnZero()
return file, nil
}
@ -90,7 +89,7 @@ func (f *File) Close() error {
}
func (f *File) closeOnZero() {
defer logging.Debug().
defer log.Debug().
Str("path", f.path).
Msg("access log closed")

View file

@ -1,4 +1,3 @@
//nolint:zerologlint
package logging
import (
@ -10,6 +9,8 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/utils/strutils"
zerologlog "github.com/rs/zerolog/log"
)
var (
@ -61,22 +62,6 @@ func InitLogger(out ...io.Writer) {
log.SetOutput(writer)
log.SetPrefix("")
log.SetFlags(0)
zerolog.TimeFieldFormat = timeFmt
zerologlog.Logger = logger
}
func DiscardLogger() { zerolog.SetGlobalLevel(zerolog.Disabled) }
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
func GetLogger() *zerolog.Logger { return &logger }
func With() zerolog.Context { return logger.With() }
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
func Info() *zerolog.Event { return logger.Info() }
func Warn() *zerolog.Event { return logger.Warn() }
func Error() *zerolog.Event { return logger.Error() }
func Err(err error) *zerolog.Event { return logger.Err(err) }
func Debug() *zerolog.Event { return logger.Debug() }
func Fatal() *zerolog.Event { return logger.Fatal() }
func Panic() *zerolog.Event { return logger.Panic() }
func Trace() *zerolog.Event { return logger.Trace() }

View file

@ -2,8 +2,8 @@ package maxmind
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
)
type (
@ -28,6 +28,6 @@ func (cfg *Config) Validate() gperr.Error {
}
func (cfg *Config) Logger() *zerolog.Logger {
l := logging.With().Str("database", string(cfg.Database)).Logger()
l := log.With().Str("database", string(cfg.Database)).Logger()
return &l
}

View file

@ -10,8 +10,8 @@ import (
"sync"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/atomic"
)
@ -47,7 +47,7 @@ var initDataDirOnce sync.Once
func initDataDir() {
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil {
logging.Error().Err(err).Msg("failed to create metrics data directory")
log.Error().Err(err).Msg("failed to create metrics data directory")
}
}
@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
}
func (p *Poller[T, AggregateT]) savePath() string {
return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name))
return filepath.Join(saveBaseDir, p.name+".json")
}
func (p *Poller[T, AggregateT]) load() error {
@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
func (p *Poller[T, AggregateT]) Start() {
t := task.RootTask("poller." + p.name)
l := log.With().Str("name", p.name).Logger()
err := p.load()
if err != nil {
if !os.IsNotExist(err) {
logging.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name)
l.Err(err).Msg("failed to load last metrics data")
}
} else {
logging.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total())
l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data")
}
go func() {
@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() {
gatherErrsTicker.Stop()
saveTicker.Stop()
p.save()
if err := p.save(); err != nil {
l.Err(err).Msg("failed to save metrics data")
}
t.Finish(nil)
}()
logging.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval)
l.Debug().Dur("interval", pollInterval).Msg("Starting poller")
p.pollWithTimeout(t.Context())
@ -176,7 +179,7 @@ func (p *Poller[T, AggregateT]) Start() {
case <-gatherErrsTicker.C:
errs, ok := p.gatherErrs()
if ok {
logging.Error().Msg(errs)
log.Error().Msg(errs)
}
p.clearErrs()
}

View file

@ -8,6 +8,7 @@ import (
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
@ -16,7 +17,6 @@ import (
"github.com/shirou/gopsutil/v4/warning"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/metrics/period"
)
@ -130,7 +130,7 @@ func getSystemInfo(ctx context.Context, lastResult *SystemInfo) (*SystemInfo, er
}
})
if allWarnings.HasError() {
logging.Warn().Msg(allWarnings.String())
log.Warn().Msg(allWarnings.String())
}
if allErrors.HasError() {
return nil, allErrors.Error()
@ -195,7 +195,7 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf
if len(s.Disks) == 0 {
return errs.Error()
}
logging.Warn().Msg(errs.String())
log.Warn().Msg(errs.String())
}
return nil
}

View file

@ -6,7 +6,7 @@ import (
"errors"
"net/http"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
func WriteBody(w http.ResponseWriter, body []byte) {
@ -14,9 +14,9 @@ func WriteBody(w http.ResponseWriter, body []byte) {
switch {
case errors.Is(err, http.ErrHandlerTimeout),
errors.Is(err, context.DeadlineExceeded):
logging.Err(err).Msg("timeout writing body")
log.Err(err).Msg("timeout writing body")
default:
logging.Err(err).Msg("failed to write body")
log.Err(err).Msg("failed to write body")
}
}
}

View file

@ -9,11 +9,11 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
func warnNoMatchDomains() {
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
log.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
}
var warnNoMatchDomainOnce sync.Once

View file

@ -7,8 +7,8 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
"github.com/yusing/go-proxy/internal/task"
@ -47,7 +47,7 @@ func New(cfg *Config) *LoadBalancer {
lb := &LoadBalancer{
Config: new(Config),
pool: pool.New[Server]("loadbalancer." + cfg.Link),
l: logging.With().Str("name", cfg.Link).Logger(),
l: log.With().Str("name", cfg.Link).Logger(),
}
lb.UpdateConfigIfNeeded(cfg)
return lb

View file

@ -1,3 +1,3 @@
package types
type Weight uint16
type Weight int

View file

@ -4,14 +4,14 @@ import (
"net/http"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level).
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
return log.WithLevel(level). //nolint:zerologlint
Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI)
}
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }

View file

@ -4,8 +4,8 @@ import (
"net/http"
"text/template"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/auth"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
_ "embed"
@ -55,7 +55,7 @@ func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed boo
"FormHTML": p.FormHTML(),
})
if err != nil {
logging.Error().Err(err).Msg("failed to execute captcha page")
log.Error().Err(err).Msg("failed to execute captcha page")
}
return false
}

View file

@ -7,7 +7,7 @@ import (
"github.com/go-playground/validator/v10"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@ -34,7 +34,7 @@ var (
)
func init() {
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
serialization.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
statusCode := fl.Field().Int()
return gphttp.IsStatusCodeValid(int(statusCode))
})
@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
for _, cidr := range wl.Allow {
if cidr.Contains(ip) {
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
}
}
if !allow {

View file

@ -8,7 +8,7 @@ import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
t.Run("invalid cidr", func(t *testing.T) {
_, err := CIDRWhiteList.New(OptionsRaw{
@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
"status_code": 600,
"message": testMessage,
})
ExpectError(t, utils.ErrValidationError, err)
ExpectError(t, serialization.ErrValidationError, err)
})
}

View file

@ -1,6 +1,7 @@
package middleware
import (
"context"
"errors"
"fmt"
"io"
@ -9,8 +10,8 @@ import (
"sync"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/atomic"
"github.com/yusing/go-proxy/internal/utils/strutils"
@ -89,21 +90,29 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
)
if err != nil {
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
log.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
return nil
}
if len(cfCIDRs) == 0 {
logging.Warn().Msg("cloudflare CIDR range is empty")
log.Warn().Msg("cloudflare CIDR range is empty")
}
}
cfCIDRsLastUpdate.Store(time.Now())
logging.Info().Msg("cloudflare CIDR range updated")
log.Info().Msg("cloudflare CIDR range updated")
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
resp, err := http.Get(endpoint)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req) //nolint:gosec
if err != nil {
return err
}

View file

@ -8,7 +8,7 @@ import (
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
@ -32,7 +32,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
if ok {
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
log.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
@ -40,7 +40,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
log.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
return nil
}
@ -56,7 +56,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
filename := path[len(StaticFilePathPrefix):]
file, ok := errorpage.GetStaticFile(filename)
if !ok {
logging.Error().Msg("unable to load resource " + filename)
log.Error().Msg("unable to load resource " + filename)
return false
}
ext := filepath.Ext(filename)
@ -68,10 +68,10 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
case ".css":
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
log.Error().Msgf("unexpected file type %q for %s", ext, filename)
}
if _, err := w.Write(file); err != nil {
logging.Err(err).Msg("unable to write resource " + filename)
log.Err(err).Msg("unable to write resource " + filename)
http.Error(w, "Error page failure", http.StatusInternalServerError)
}
return true

View file

@ -6,9 +6,9 @@ import (
"path"
"sync"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional"
@ -48,7 +48,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
func loadContent() {
files, err := U.ListFiles(errPagesBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list error page resources")
log.Err(err).Msg("failed to list error page resources")
return
}
for _, file := range files {
@ -57,11 +57,11 @@ func loadContent() {
}
content, err := os.ReadFile(file)
if err != nil {
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
log.Warn().Err(err).Msgf("failed to read error page resource %s", file)
continue
}
file = path.Base(file)
logging.Info().Msgf("error page resource %s loaded", file)
log.Info().Msgf("error page resource %s loaded", file)
fileContentMap.Store(file, content)
}
}
@ -83,9 +83,9 @@ func watchDir() {
loadContent()
case events.ActionFileDeleted:
fileContentMap.Delete(filename)
logging.Warn().Msgf("error page resource %s deleted", filename)
log.Warn().Msgf("error page resource %s deleted", filename)
case events.ActionFileRenamed:
logging.Warn().Msgf("error page resource %s deleted", filename)
log.Warn().Msgf("error page resource %s deleted", filename)
fileContentMap.Delete(filename)
loadContent()
}

View file

@ -8,11 +8,11 @@ import (
"sort"
"strings"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
@ -87,7 +87,7 @@ func NewMiddleware[ImplType any]() *Middleware {
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.enableTrace()
logging.Trace().Msgf("middleware %s enabled trace", m.name)
log.Trace().Msgf("middleware %s enabled trace", m.name)
}
}
@ -118,14 +118,14 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
"priority": optsRaw["priority"],
"bypass": optsRaw["bypass"],
}
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
return err
}
optsRaw = maps.Clone(optsRaw)
for k := range commonOpts {
delete(optsRaw, k)
}
return utils.MapUnmarshalValidate(optsRaw, m.impl)
return serialization.MapUnmarshalValidate(optsRaw, m.impl)
}
func (m *Middleware) finalize() error {

View file

@ -3,9 +3,9 @@ package middleware
import (
"path"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -59,7 +59,7 @@ func LoadComposeFiles() {
errs := gperr.NewBuilder("middleware compile errors")
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logging.Err(err).Msg("failed to list middleware definitions")
log.Err(err).Msg("failed to list middleware definitions")
return
}
for _, defFile := range middlewareDefs {
@ -75,7 +75,7 @@ func LoadComposeFiles() {
continue
}
allMiddlewares[name] = m
logging.Info().
log.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")
@ -94,7 +94,7 @@ func LoadComposeFiles() {
continue
}
allMiddlewares[name] = m
logging.Info().
log.Info().
Str("src", path.Base(defFile)).
Str("name", name).
Msg("middleware loaded")

View file

@ -25,7 +25,7 @@ import (
"sync"
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
@ -138,7 +138,7 @@ func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper
panic("nil transport")
}
rp := &ReverseProxy{
Logger: logging.With().Str("name", name).Logger(),
Logger: log.With().Str("name", name).Logger(),
Transport: transport,
TargetName: name,
TargetURL: target,
@ -173,17 +173,17 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
case errors.Is(err, context.Canceled),
errors.Is(err, io.EOF),
errors.Is(err, context.DeadlineExceeded):
logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
log.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
default:
var recordErr tls.RecordHeaderError
if errors.As(err, &recordErr) {
logging.Error().
log.Error().
Str("url", reqURL).
Msgf(`scheme was likely misconfigured as https,
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
logging.Err(err).Msg("underlying error")
log.Err(err).Msg("underlying error")
} else {
logging.Err(err).Str("url", reqURL).Msg("http proxy error")
log.Err(err).Str("url", reqURL).Msg("http proxy error")
}
}
@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
/* trunk-ignore(golangci-lint/revive) */
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
res, err := transport.RoundTrip(outreq)
@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
//nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
return
}
if err := brw.Flush(); err != nil {
/* trunk-ignore(golangci-lint/errorlint) */
//nolint:errorlint
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
return
}
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
/* trunk-ignore(golangci-lint/errcheck) */
//nolint:errcheck
bdp.Start()
}

View file

@ -9,14 +9,14 @@ import (
"github.com/quic-go/quic-go/http3"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/acl"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
)
type CertProvider interface {
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type Server struct {
@ -53,7 +53,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) {
func NewServer(opt Options) (s *Server) {
var httpSer, httpsSer *http.Server
logger := logging.With().Str("server", opt.Name).Logger()
logger := log.With().Str("server", opt.Name).Logger()
certAvailable := false
if opt.CertProvider != nil {

View file

@ -2,7 +2,7 @@ package notif
import (
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type NotificationConfig struct {
@ -46,5 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error)
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
}
return utils.MapUnmarshalValidate(m, cfg.Provider)
return serialization.MapUnmarshalValidate(m, cfg.Provider)
}

View file

@ -4,7 +4,7 @@ import (
"net/http"
"testing"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -181,7 +181,7 @@ func TestNotificationConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
var cfg NotificationConfig
provider := tt.cfg["provider"]
err := utils.MapUnmarshalValidate(tt.cfg, &cfg)
err := serialization.MapUnmarshalValidate(tt.cfg, &cfg)
if tt.wantErr {
ExpectHasError(t, err)
} else {

View file

@ -8,14 +8,14 @@ import (
"net/http"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
)
type (
Provider interface {
utils.CustomValidator
serialization.CustomValidator
GetName() string
GetURL() string
@ -73,7 +73,7 @@ func (msg *LogMessage) notify(ctx context.Context, provider Provider) error {
switch resp.StatusCode {
case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent:
body, _ := io.ReadAll(resp.Body)
logging.Debug().
log.Debug().
Str("provider", provider.GetName()).
Str("url", provider.GetURL()).
Str("status", resp.Status).

View file

@ -7,12 +7,12 @@ import (
"github.com/docker/docker/client"
"github.com/goccy/go-yaml"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/route"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/internal/watcher"
)
@ -36,7 +36,7 @@ func DockerProviderImpl(name, dockerHost string) ProviderImpl {
return &DockerProvider{
name,
dockerHost,
logging.With().Str("type", "docker").Str("name", name).Logger(),
log.With().Str("type", "docker").Str("name", name).Logger(),
}
}
@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
// Always non-nil.
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
if !container.IsExplicit && p.IsExplicitOnly() {
return nil, nil
return make(route.Routes, 0), nil
}
routes := make(route.Routes, len(container.Aliases))
@ -180,7 +180,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
}
// deserialize map into entry object
err := U.MapUnmarshalValidate(entryMap, r)
err := serialization.MapUnmarshalValidate(entryMap, r)
if err != nil {
errs.Add(err.Subject(alias))
} else {
@ -189,7 +189,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
}
if wildcardProps != nil {
for _, re := range routes {
if err := U.MapUnmarshalValidate(wildcardProps, re); err != nil {
if err := serialization.MapUnmarshalValidate(wildcardProps, re); err != nil {
errs.Add(err.Subject(docker.WildcardAlias))
break
}

View file

@ -6,11 +6,11 @@ import (
"strings"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
W "github.com/yusing/go-proxy/internal/watcher"
)
@ -24,7 +24,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
impl := &FileProvider{
fileName: filename,
path: path.Join(common.ConfigBasePath, filename),
l: logging.With().Str("type", "file").Str("name", filename).Logger(),
l: log.With().Str("type", "file").Str("name", filename).Logger(),
}
_, err := os.Stat(impl.path)
if err != nil {
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
}
func validate(data []byte) (routes route.Routes, err gperr.Error) {
err = utils.UnmarshalValidateYAML(data, &routes)
err = serialization.UnmarshalValidateYAML(data, &routes)
return
}

View file

@ -7,12 +7,12 @@ import (
"time"
"github.com/docker/docker/api/types/container"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/homepage"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
"github.com/yusing/go-proxy/internal/logging"
netutils "github.com/yusing/go-proxy/internal/net"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxmox"
@ -116,7 +116,7 @@ func (r *Route) Validate() gperr.Error {
Subject(containerName)
}
l := logging.With().Str("container", containerName).Logger()
l := log.With().Str("container", containerName).Logger()
l.Info().Msg("checking if container is running")
running, err := node.LXCIsRunning(ctx, vmid)

View file

@ -3,7 +3,7 @@ package rules
import (
"testing"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) {
var rules struct {
Rules Rules
}
err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules)
err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
ExpectNoError(t, err)
ExpectEqual(t, len(rules.Rules), len(test))
ExpectEqual(t, rules.Rules[0].Name, "test")

View file

@ -5,9 +5,9 @@ import (
"errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/idlewatcher"
"github.com/yusing/go-proxy/internal/logging"
net "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
@ -32,7 +32,7 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
// TODO: support non-coherent scheme
return &StreamRoute{
Route: base,
l: logging.With().
l: log.With().
Str("type", string(base.Scheme)).
Str("name", base.Name()).
Logger(),

View file

@ -6,7 +6,7 @@ import (
. "github.com/yusing/go-proxy/internal/route"
route "github.com/yusing/go-proxy/internal/route/types"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/serialization"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
cfg := Route{}
tt.input["host"] = "internal"
err := utils.MapUnmarshalValidate(tt.input, &cfg)
err := serialization.MapUnmarshalValidate(tt.input, &cfg)
if err != nil {
expect.NoError(t, err)
}

View file

@ -6,8 +6,8 @@ import (
"net"
"sync"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@ -88,7 +88,7 @@ func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
if err == nil {
logging.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
}
return
}
@ -102,7 +102,7 @@ func (conn *UDPConn) read() (err error) {
conn.buf.oobn = 0
}
if err == nil {
logging.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
return
}
@ -110,7 +110,7 @@ func (conn *UDPConn) read() (err error) {
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
if err == nil {
logging.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
}
return
}
@ -120,12 +120,12 @@ func (conn *UDPConn) write() (err error) {
case *net.UDPConn:
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
if err == nil {
logging.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
default:
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
if err == nil {
logging.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
}
}

View file

@ -1,4 +1,4 @@
package utils
package serialization
import (
"encoding/json"
@ -12,7 +12,9 @@ import (
"github.com/go-playground/validator/v10"
"github.com/goccy/go-yaml"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -40,14 +42,14 @@ var (
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
var defaultValues = xsync.NewMapOf[reflect.Type, func() any]()
func RegisterDefaultValueFactory[T any](factory func() *T) {
t := reflect.TypeFor[T]()
if t.Kind() == reflect.Ptr {
panic("pointer of pointer")
}
if defaultValues.Has(t) {
if _, ok := defaultValues.Load(t); ok {
panic("default value for " + t.String() + " already registered")
}
defaultValues.Store(t, func() any { return factory() })
@ -259,7 +261,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
errs.Add(err.Subject(k))
}
} else {
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping))))
}
}
if hasValidateTag && checkValidateTag {

View file

@ -1,4 +1,4 @@
package utils
package serialization
import (
"reflect"

View file

@ -1,4 +1,4 @@
package utils
package serialization
import (
"github.com/go-playground/validator/v10"

View file

@ -7,9 +7,9 @@ import (
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -116,14 +116,14 @@ func (t *Task) Finish(reason any) {
func (t *Task) finish(reason any) {
t.cancel(fmtCause(reason))
if !waitWithTimeout(t.childrenDone) {
logging.Debug().
log.Debug().
Str("task", t.name).
Strs("subtasks", t.listChildren()).
Msg("Timeout waiting for subtasks to finish")
}
go t.runCallbacks()
if !waitWithTimeout(t.callbacksDone) {
logging.Debug().
log.Debug().
Str("task", t.name).
Strs("callbacks", t.listCallbacks()).
Msg("Timeout waiting for callbacks to finish")
@ -134,7 +134,7 @@ func (t *Task) finish(reason any) {
}
t.parent.subChildCount()
allTasks.Remove(t)
logging.Trace().Msg("task " + t.name + " finished")
log.Trace().Msg("task " + t.name + " finished")
}
// Subtask returns a new subtask with the given name, derived from the parent's context.
@ -166,7 +166,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
}()
}
logging.Trace().Msg("task " + child.name + " started")
log.Trace().Msg("task " + child.name + " started")
return child
}
@ -189,7 +189,7 @@ func (t *Task) MarshalText() ([]byte, error) {
func (t *Task) invokeWithRecover(fn func(), caller string) {
defer func() {
if err := recover(); err != nil {
logging.Error().
log.Error().
Interface("err", err).
Msg("panic in task " + t.name + "." + caller)
if common.IsDebug {

View file

@ -9,7 +9,7 @@ import (
"syscall"
"time"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
@ -68,10 +68,10 @@ func GracefulShutdown(timeout time.Duration) (err error) {
case <-after:
b, err := json.Marshal(DebugTaskList())
if err != nil {
logging.Warn().Err(err).Msg("failed to marshal tasks")
log.Warn().Err(err).Msg("failed to marshal tasks")
return context.DeadlineExceeded
}
logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
return context.DeadlineExceeded
}
}
@ -87,6 +87,6 @@ func WaitExit(shutdownTimeout int) {
<-sig
// gracefully shutdown
logging.Info().Msg("shutting down")
log.Info().Msg("shutting down")
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
}

21
internal/utils/go.mod Normal file
View file

@ -0,0 +1,21 @@
module github.com/yusing/go-proxy/internal/utils
go 1.24.3
require (
github.com/goccy/go-yaml v1.17.1
github.com/puzpuzpuz/xsync/v4 v4.1.0
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
go.uber.org/atomic v1.11.0
golang.org/x/text v0.25.0
)
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
golang.org/x/sys v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

36
internal/utils/go.sum Normal file
View file

@ -0,0 +1,36 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U=
github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -8,7 +8,6 @@ import (
"sync"
"syscall"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils/synk"
)
@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
}
}
func (p BidirectionalPipe) Start() gperr.Error {
func (p BidirectionalPipe) Start() error {
var wg sync.WaitGroup
wg.Add(2)
b := gperr.NewBuilder("bidirectional pipe error")
var srcErr, dstErr error
go func() {
b.Add(p.pSrcDst.Start())
srcErr = p.pSrcDst.Start()
wg.Done()
}()
go func() {
b.Add(p.pDstSrc.Start())
dstErr = p.pDstSrc.Start()
wg.Done()
}()
wg.Wait()
return b.Error()
return errors.Join(srcErr, dstErr)
}
type httpFlusher interface {
@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
wCloser, wCanClose := dst.Writer.(io.Closer)
rCloser, rCanClose := src.Reader.(io.Closer)
if wCanClose || rCanClose {
if src.ctx == dst.ctx {
go func() {
<-src.ctx.Done()
if wCanClose {
wCloser.Close()
}
if rCanClose {
rCloser.Close()
}
}()
} else {
if wCloser != nil {
go func() {
<-src.ctx.Done()
wCloser.Close()
}()
go func() {
select {
case <-src.ctx.Done():
case <-dst.ctx.Done():
}
if rCloser != nil {
go func() {
<-dst.ctx.Done()
rCloser.Close()
}()
if rCanClose {
defer rCloser.Close()
}
}
if wCanClose {
defer wCloser.Close()
}
}()
}
flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil

View file

@ -4,7 +4,7 @@ import (
"sort"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
type (
@ -29,12 +29,12 @@ func (p Pool[T]) Name() string {
func (p Pool[T]) Add(obj T) {
p.checkExists(obj.Key())
p.m.Store(obj.Key(), obj)
logging.Info().Msgf("%s: added %s", p.name, obj.Name())
log.Info().Msgf("%s: added %s", p.name, obj.Name())
}
func (p Pool[T]) Del(obj T) {
p.m.Delete(obj.Key())
logging.Info().Msgf("%s: removed %s", p.name, obj.Name())
log.Info().Msgf("%s: removed %s", p.name, obj.Name())
}
func (p Pool[T]) Get(key string) (T, bool) {

View file

@ -5,11 +5,11 @@ package pool
import (
"runtime/debug"
"github.com/yusing/go-proxy/internal/logging"
"github.com/rs/zerolog/log"
)
func (p Pool[T]) checkExists(key string) {
if _, ok := p.m.Load(key); ok {
logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
}
}

View file

@ -1,16 +1,35 @@
package synk
import (
"os"
"os/signal"
"sync/atomic"
"time"
"github.com/yusing/go-proxy/internal/logging"
"runtime"
"unsafe"
)
type weakBuf = unsafe.Pointer
func makeWeak(b *[]byte) weakBuf {
ptr := runtime_registerWeakPointer(unsafe.Pointer(b))
runtime.KeepAlive(ptr)
addCleanup(b, addGCed, cap(*b))
return weakBuf(ptr)
}
func getBufFromWeak(w weakBuf) []byte {
ptr := (*[]byte)(runtime_makeStrongFromWeak(w))
if ptr == nil {
return nil
}
return *ptr
}
//go:linkname runtime_registerWeakPointer weak.runtime_registerWeakPointer
func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer
//go:linkname runtime_makeStrongFromWeak weak.runtime_makeStrongFromWeak
func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer
type BytesPool struct {
pool chan []byte
pool chan weakBuf
initSize int
}
@ -22,19 +41,15 @@ const (
const (
InPoolLimit = 32 * mb
DefaultInitBytes = 32 * kb
PoolThreshold = 64 * kb
DefaultInitBytes = 4 * kb
PoolThreshold = 256 * kb
DropThresholdHigh = 4 * mb
PoolSize = InPoolLimit / PoolThreshold
CleanupInterval = 5 * time.Second
MaxDropsPerCycle = 10
MaxChecksPerCycle = 100
)
var bytesPool = &BytesPool{
pool: make(chan []byte, PoolSize),
pool: make(chan weakBuf, PoolSize),
initSize: DefaultInitBytes,
}
@ -43,12 +58,18 @@ func NewBytesPool() *BytesPool {
}
func (p *BytesPool) Get() []byte {
select {
case b := <-p.pool:
subInPoolSize(int64(cap(b)))
return b
default:
return make([]byte, 0, p.initSize)
for {
select {
case bWeak := <-p.pool:
bPtr := getBufFromWeak(bWeak)
if bPtr == nil {
continue
}
addReused(cap(bPtr))
return bPtr
default:
return make([]byte, 0, p.initSize)
}
}
}
@ -56,90 +77,43 @@ func (p *BytesPool) GetSized(size int) []byte {
if size <= PoolThreshold {
return make([]byte, size)
}
select {
case b := <-p.pool:
if size <= cap(b) {
subInPoolSize(int64(cap(b)))
return b[:size]
}
for {
select {
case p.pool <- b:
addInPoolSize(int64(cap(b)))
case bWeak := <-p.pool:
bPtr := getBufFromWeak(bWeak)
if bPtr == nil {
continue
}
capB := cap(bPtr)
if capB >= size {
addReused(capB)
return (bPtr)[:size]
}
select {
case p.pool <- bWeak:
default:
// just drop it
}
default:
}
default:
return make([]byte, size)
}
return make([]byte, size)
}
func (p *BytesPool) Put(b []byte) {
size := cap(b)
if size > DropThresholdHigh || poolFull() {
if size <= PoolThreshold || size > DropThresholdHigh {
return
}
b = b[:0]
w := makeWeak(&b)
select {
case p.pool <- b:
addInPoolSize(int64(size))
return
case p.pool <- w:
default:
// just drop it
}
}
var inPoolSize int64
func addInPoolSize(size int64) {
atomic.AddInt64(&inPoolSize, size)
}
func subInPoolSize(size int64) {
atomic.AddInt64(&inPoolSize, -size)
}
func init() {
// Periodically drop some buffers to prevent excessive memory usage
go func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt)
cleanupTicker := time.NewTicker(CleanupInterval)
defer cleanupTicker.Stop()
for {
select {
case <-cleanupTicker.C:
dropBuffers()
case <-sigCh:
return
}
}
}()
}
func poolFull() bool {
return atomic.LoadInt64(&inPoolSize) >= InPoolLimit
}
// dropBuffers removes excess buffers from the pool when it grows too large.
func dropBuffers() {
// Check if pool has more than a threshold of buffers
count := 0
droppedSize := 0
checks := 0
for count < MaxDropsPerCycle && checks < MaxChecksPerCycle && atomic.LoadInt64(&inPoolSize) > InPoolLimit*2/3 {
select {
case b := <-bytesPool.pool:
n := cap(b)
subInPoolSize(int64(n))
droppedSize += n
count++
default:
time.Sleep(10 * time.Millisecond)
}
checks++
}
if count > 0 {
logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
}
initPoolStats()
}

View file

@ -4,7 +4,7 @@ import (
"testing"
)
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024}
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024}
func BenchmarkBytesPool_GetSmall(b *testing.B) {
for b.Loop() {

View file

@ -0,0 +1,55 @@
//go:build !production
package synk
import (
"os"
"os/signal"
"runtime"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var (
numReused, sizeReused uint64
numGCed, sizeGCed uint64
)
func addReused(size int) {
atomic.AddUint64(&numReused, 1)
atomic.AddUint64(&sizeReused, uint64(size))
}
func addGCed(size int) {
atomic.AddUint64(&numGCed, 1)
atomic.AddUint64(&sizeGCed, uint64(size))
}
var addCleanup = runtime.AddCleanup[[]byte, int]
func initPoolStats() {
go func() {
statsTicker := time.NewTicker(5 * time.Second)
defer statsTicker.Stop()
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt)
for {
select {
case <-sig:
return
case <-statsTicker.C:
log.Info().
Uint64("numReused", atomic.LoadUint64(&numReused)).
Str("sizeReused", strutils.FormatByteSize(atomic.LoadUint64(&sizeReused))).
Uint64("numGCed", atomic.LoadUint64(&numGCed)).
Str("sizeGCed", strutils.FormatByteSize(atomic.LoadUint64(&sizeGCed))).
Msg("bytes pool stats")
}
}
}()
}

View file

@ -0,0 +1,8 @@
//go:build production
package synk
func addReused(size int) {}
func addGCed(size int) {}
func initPoolStats() {}
func addCleanup(ptr *[]byte, cleanup func(int), arg int) {}

View file

@ -2,14 +2,16 @@ package expect
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
var isTest = strings.HasSuffix(os.Args[0], ".test")
func init() {
if common.IsTest {
if isTest {
// force verbose output
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}

View file

@ -1,19 +1,11 @@
package expect
import (
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/go-proxy/internal/common"
)
func init() {
if common.IsTest {
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
}
}
func ExpectNoError(t *testing.T, err error) {
t.Helper()
require.NoError(t, err)

View file

@ -3,7 +3,6 @@ package utils
import (
"time"
"github.com/yusing/go-proxy/internal/task"
"go.uber.org/atomic"
)
@ -38,8 +37,6 @@ func init() {
go func() {
for {
select {
case <-task.RootContext().Done():
return
case <-timeNowTicker.C:
shouldCallTimeNow.Store(true)
}

View file

@ -8,8 +8,8 @@ import (
"github.com/fsnotify/fsnotify"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/events"
)
@ -41,13 +41,13 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
//! subdirectories are not watched
w, err := fsnotify.NewWatcher()
if err != nil {
logging.Panic().Err(err).Msg("unable to create fs watcher")
log.Panic().Err(err).Msg("unable to create fs watcher")
}
if err = w.Add(dirPath); err != nil {
logging.Panic().Err(err).Msg("unable to create fs watcher")
log.Panic().Err(err).Msg("unable to create fs watcher")
}
helper := &DirWatcher{
Logger: logging.With().
Logger: log.With().
Str("type", "dir").
Str("path", dirPath).
Logger(),

View file

@ -8,9 +8,9 @@ import (
docker_events "github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/watcher/events"
)
@ -81,7 +81,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
}()
cEventCh, cErrCh := client.Events(ctx, options)
defer logging.Debug().Str("host", client.Address()).Msg("docker watcher closed")
defer log.Debug().Str("host", client.Address()).Msg("docker watcher closed")
for {
select {
case <-ctx.Done():
@ -153,7 +153,7 @@ func checkConnection(ctx context.Context, client *docker.SharedClient) bool {
defer cancel()
err := client.CheckConnection(ctx)
if err != nil {
logging.Debug().Err(err).Msg("docker watcher: connection failed")
log.Debug().Err(err).Msg("docker watcher: connection failed")
return false
}
return true

View file

@ -7,9 +7,9 @@ import (
"net/url"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/notif"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/task"
@ -48,7 +48,7 @@ func NewMonitor(r routes.Route) health.HealthMonCheck {
case routes.StreamRoute:
mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig())
default:
logging.Panic().Msgf("unexpected route type: %T", r)
log.Panic().Msgf("unexpected route type: %T", r)
}
}
if r.IsDocker() {
@ -91,7 +91,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
mon.task = parent.Subtask("health_monitor", true)
go func() {
logger := logging.With().Str("name", mon.service).Logger()
logger := log.With().Str("name", mon.service).Logger()
defer func() {
if mon.status.Load() != health.StatusError {
@ -221,7 +221,7 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
}
func (mon *monitor) checkUpdateHealth() error {
logger := logging.With().Str("name", mon.Name()).Logger()
logger := log.With().Str("name", mon.Name()).Logger()
result, err := mon.checkHealth()
var lastStatus health.Status

View file

@ -18,7 +18,7 @@ func GetLastVersion() Version {
func GetVersionHTTPHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(GetVersion().String()))
fmt.Fprint(w, GetVersion().String())
}
}
@ -34,16 +34,16 @@ func init() {
// lastVersion = ParseVersion(lastVersionStr)
// }
// if err != nil && !os.IsNotExist(err) {
// logging.Warn().Err(err).Msg("failed to read version file")
// log.Warn().Err(err).Msg("failed to read version file")
// return
// }
// if err := f.Truncate(0); err != nil {
// logging.Warn().Err(err).Msg("failed to truncate version file")
// log.Warn().Err(err).Msg("failed to truncate version file")
// return
// }
// _, err = f.WriteString(version)
// if err != nil {
// logging.Warn().Err(err).Msg("failed to save version file")
// log.Warn().Err(err).Msg("failed to save version file")
// return
// }
}

View file

@ -2,4 +2,19 @@ module github.com/yusing/go-proxy/socketproxy
go 1.24.3
require github.com/gorilla/mux v1.8.1
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
require (
github.com/gorilla/mux v1.8.1
github.com/yusing/go-proxy/internal/utils v0.0.0-00010101000000-000000000000
golang.org/x/net v0.40.0
)
require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/rs/zerolog v1.34.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0 // indirect
)

View file

@ -1,2 +1,34 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -4,30 +4,33 @@ import (
"context"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/yusing/go-proxy/socketproxy/pkg/reverseproxy"
)
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", DockerSocket)
func dialDockerSocket(socket string) func(ctx context.Context, _, _ string) (net.Conn, error) {
return func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", socket)
}
}
var DockerSocketHandler = dockerSocketHandler
func dockerSocketHandler() http.HandlerFunc {
rp := &httputil.ReverseProxy{
func dockerSocketHandler(socket string) http.HandlerFunc {
rp := &reverseproxy.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = "http"
req.URL.Host = "api.moby.localhost"
req.RequestURI = req.URL.String()
},
Transport: &http.Transport{
DialContext: dialDockerSocket,
DialContext: dialDockerSocket(socket),
DisableCompression: true,
},
}
@ -41,7 +44,7 @@ func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
func NewHandler() http.Handler {
r := mux.NewRouter()
socketHandler := DockerSocketHandler()
socketHandler := DockerSocketHandler(DockerSocket)
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"

View file

@ -9,7 +9,7 @@ import (
. "github.com/yusing/go-proxy/socketproxy/pkg"
)
func mockDockerSocketHandler() http.HandlerFunc {
func mockDockerSocketHandler(_ string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("mock docker response"))

View file

@ -0,0 +1,367 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// License URL: https://cs.opensource.google/go/go/+/master:LICENSE
// HTTP reverse proxy handler
package reverseproxy
import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httptrace"
"net/textproto"
"strings"
"sync"
"github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Director is a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
//
// By default, the X-Forwarded-For header is set to the
// value of the client IP address. If an X-Forwarded-For
// header already exists, the client IP is appended to the
// existing values. As a special case, if the header
// exists in the Request.Header map but has a nil value
// (such as when set by the Director func), the X-Forwarded-For
// header is not modified.
//
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
//
// Hop-by-hop headers are removed from the request after
// Director returns, which can remove headers added by
// Director. Use a Rewrite function instead to ensure
// modifications to the request are preserved.
//
// Unparsable query parameters are removed from the outbound
// request if Request.Form is set after Director returns.
//
// At most one of Rewrite or Director may be set.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
// ErrorHandler is an optional function that handles errors
// reaching the backend or errors from ModifyResponse.
//
// If nil, the default is to log the provided error and return
// a 502 Status Bad Gateway response.
ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusBadGateway)
}
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
if p.ErrorHandler != nil {
return p.ErrorHandler
}
return p.defaultErrorHandler
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
ctx := req.Context()
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handle returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
p.Director(outreq)
outreq.Close = false
reqUpType := upgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
return
}
req.Header.Del("Forwarded")
removeHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
var (
roundTripMutex sync.Mutex
roundTripDone bool
)
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
roundTripMutex.Lock()
defer roundTripMutex.Unlock()
if roundTripDone {
// If RoundTrip has returned, don't try to further modify
// the ResponseWriter's header map.
return nil
}
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
clear(h)
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
roundTripMutex.Lock()
roundTripDone = true
roundTripMutex.Unlock()
if err != nil {
p.getErrorHandler()(rw, outreq, err)
return
}
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
p.handleUpgradeResponse(rw, outreq, res)
return
}
removeHopByHopHeaders(res.Header)
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
err = utils.CopyCloseWithContext(ctx, rw, res.Body)
if err != nil {
if !errors.Is(err, context.Canceled) {
p.getErrorHandler()(rw, req, err)
}
return
}
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
http.NewResponseController(rw).Flush()
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
// removeHopByHopHeaders removes hop-by-hop headers.
func removeHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for sf := range strings.SplitSeq(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
func (p *ReverseProxy) logf(format string, args ...any) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
return
}
if !strings.EqualFold(reqUpType, resUpType) {
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
return
}
rc := http.NewResponseController(rw)
conn, brw, hijackErr := rc.Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
if hijackErr != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
return
}
if err := brw.Flush(); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
return
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc
}
// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
user, backend io.ReadWriter
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.Copy(c.user, c.backend)
errc <- err
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.Copy(c.backend, c.user)
errc <- err
}
func IsPrint(s string) bool {
for _, r := range s {
if r < ' ' || r > '~' {
return false
}
}
return true
}