mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-30 00:22:35 +02:00
fix: optimize memory usage, fix agent and code refactor (#118)
Some checks are pending
Docker Image CI (socket-proxy) / build (push) Waiting to run
Some checks are pending
Docker Image CI (socket-proxy) / build (push) Waiting to run
* refactor: simplify io code and make utils module independent * fix(docker): agent and socket-proxy docker event flushing with modified reverse proxy handler * refactor: remove unused code * refactor: remove the use of logging module in most code * refactor: streamline domain mismatch check in certState function * tweak: use ecdsa p-256 for autocert * fix(tests): update health check tests for invalid host and add case for port in host * feat(acme): custom acme directory * refactor: code refactor and improved context and error handling * tweak: optimize memory usage under load * fix(oidc): restore old user matching behavior * docs: add ChatGPT assistant to README --------- Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
parent
ff08c40403
commit
4a8bd48ad5
98 changed files with 1549 additions and 555 deletions
|
@ -2,15 +2,16 @@ version: "2"
|
|||
linters:
|
||||
default: all
|
||||
disable:
|
||||
- bodyclose
|
||||
# - bodyclose
|
||||
- containedctx
|
||||
- contextcheck
|
||||
# - contextcheck
|
||||
- cyclop
|
||||
- depguard
|
||||
- dupl
|
||||
# - dupl
|
||||
- err113
|
||||
- exhaustive
|
||||
- exhaustruct
|
||||
- funcorder
|
||||
- forcetypeassert
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
|
@ -18,7 +19,6 @@ linters:
|
|||
- goconst
|
||||
- gocyclo
|
||||
- gomoddirectives
|
||||
- gosec
|
||||
- gosmopolitan
|
||||
- ireturn
|
||||
- lll
|
||||
|
@ -27,12 +27,10 @@ linters:
|
|||
- mnd
|
||||
- nakedret
|
||||
- nestif
|
||||
- nilnil
|
||||
- nlreturn
|
||||
- noctx
|
||||
- nonamedreturns
|
||||
- paralleltest
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- tagliatelle
|
||||
|
|
|
@ -21,7 +21,7 @@ lint:
|
|||
- markdownlint
|
||||
- yamllint
|
||||
enabled:
|
||||
- checkov@3.2.416
|
||||
- checkov@3.2.432
|
||||
- golangci-lint2@2.1.6
|
||||
- hadolint@2.12.1-beta
|
||||
- actionlint@1.7.7
|
||||
|
@ -32,7 +32,7 @@ lint:
|
|||
- prettier@3.5.3
|
||||
- shellcheck@0.10.0
|
||||
- shfmt@3.6.0
|
||||
- trufflehog@3.88.29
|
||||
- trufflehog@3.88.33
|
||||
actions:
|
||||
disabled:
|
||||
- trunk-announce
|
||||
|
|
|
@ -16,6 +16,8 @@ A lightweight, simple, and performant reverse proxy with WebUI.
|
|||
|
||||
<h5>EN | <a href="README_CHT.md">中文</a></h5>
|
||||
|
||||
Have questions? Ask [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)! (Thanks to [@ismesid](https://github.com/arevindh))
|
||||
|
||||
<img src="screenshots/webui.jpg" style="max-width: 650">
|
||||
|
||||
</div>
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
<h5><a href="README.md">EN</a> | 中文</h5>
|
||||
|
||||
有疑問? 問 [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)!(鳴謝 [@ismesid](https://github.com/arevindh))
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4bb371f4-6e4c-425c-89b2-b9e962bdd46f" style="max-width: 650">
|
||||
|
||||
</div>
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
"github.com/yusing/go-proxy/agent/pkg/server"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -14,6 +17,12 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
writer := zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
TimeFormat: "01-02 15:04",
|
||||
}
|
||||
zerolog.TimeFieldFormat = writer.TimeFormat
|
||||
log.Logger = zerolog.New(writer).Level(zerolog.InfoLevel).With().Timestamp().Logger()
|
||||
ca := &agent.PEMPair{}
|
||||
err := ca.Load(env.AgentCACert)
|
||||
if err != nil {
|
||||
|
@ -34,11 +43,11 @@ func main() {
|
|||
gperr.LogFatal("init SSL error", err)
|
||||
}
|
||||
|
||||
logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
|
||||
logging.Info().Msgf("Agent name: %s", env.AgentName)
|
||||
logging.Info().Msgf("Agent port: %d", env.AgentPort)
|
||||
log.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
|
||||
log.Info().Msgf("Agent name: %s", env.AgentName)
|
||||
log.Info().Msgf("Agent port: %d", env.AgentPort)
|
||||
|
||||
logging.Info().Msg(`
|
||||
log.Info().Msg(`
|
||||
Tips:
|
||||
1. To change the agent name, you can set the AGENT_NAME environment variable.
|
||||
2. To change the agent port, you can set the AGENT_PORT environment variable.
|
||||
|
@ -54,7 +63,7 @@ Tips:
|
|||
server.StartAgentServer(t, opts)
|
||||
|
||||
if socketproxy.ListenAddr != "" {
|
||||
logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
|
||||
log.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
|
||||
opts := httpServer.Options{
|
||||
Name: "docker",
|
||||
HTTPAddr: socketproxy.ListenAddr,
|
||||
|
|
|
@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy => ..
|
|||
|
||||
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
|
||||
|
||||
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
|
||||
|
||||
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
||||
|
||||
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327
|
||||
|
@ -15,6 +17,7 @@ require (
|
|||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
||||
github.com/yusing/go-proxy/internal/utils v0.0.0
|
||||
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/certs"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
)
|
||||
|
||||
|
@ -115,7 +115,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||
|
||||
cfg.name = string(name)
|
||||
|
||||
cfg.l = logging.With().Str("agent", cfg.name).Logger()
|
||||
cfg.l = log.With().Str("agent", cfg.name).Logger()
|
||||
|
||||
// check agent version
|
||||
agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion)
|
||||
|
@ -127,10 +127,10 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||
agentVersion := pkg.ParseVersion(cfg.version)
|
||||
|
||||
if serverVersion.IsNewerMajorThan(agentVersion) {
|
||||
logging.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
|
||||
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
|
||||
}
|
||||
|
||||
logging.Info().Msgf("agent %q initialized", cfg.name)
|
||||
log.Info().Msgf("agent %q initialized", cfg.name)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -172,9 +172,9 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
|||
{
|
||||
name: "InvalidHost",
|
||||
scheme: "tcp",
|
||||
host: "invalid",
|
||||
host: "",
|
||||
port: 8080,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedHealthy: false,
|
||||
},
|
||||
{
|
||||
|
@ -188,9 +188,17 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
|||
{
|
||||
name: "InvalidHost",
|
||||
scheme: "udp",
|
||||
host: "invalid",
|
||||
host: "",
|
||||
port: 8080,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedHealthy: false,
|
||||
},
|
||||
{
|
||||
name: "Port in both host and port",
|
||||
scheme: "tcp",
|
||||
host: "localhost:1234",
|
||||
port: 1234,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedHealthy: false,
|
||||
},
|
||||
}
|
||||
|
@ -208,9 +216,11 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
|||
|
||||
require.Equal(t, recorder.Code, tt.expectedStatus)
|
||||
|
||||
var result health.HealthCheckResult
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
|
||||
require.Equal(t, result.Healthy, tt.expectedHealthy)
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var result health.HealthCheckResult
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
|
||||
require.Equal(t, result.Healthy, tt.expectedHealthy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||
"github.com/yusing/go-proxy/pkg"
|
||||
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
|
@ -24,26 +21,6 @@ func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
|
|||
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
|
||||
}
|
||||
|
||||
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
|
||||
|
||||
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", env.DockerSocket)
|
||||
}
|
||||
|
||||
func dockerSocketHandler() http.HandlerFunc {
|
||||
rp := httputil.ReverseProxy{
|
||||
Director: func(r *http.Request) {
|
||||
r.URL.Scheme = "http"
|
||||
r.URL.Host = "api.moby.localhost"
|
||||
r.RequestURI = r.URL.String()
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
DialContext: dialDockerSocket,
|
||||
},
|
||||
}
|
||||
return rp.ServeHTTP
|
||||
}
|
||||
|
||||
func NewAgentHandler() http.Handler {
|
||||
mux := ServeMux{http.NewServeMux()}
|
||||
|
||||
|
@ -54,6 +31,6 @@ func NewAgentHandler() http.Handler {
|
|||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
|
||||
mux.ServeMux.HandleFunc("/", dockerSocketHandler())
|
||||
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler(env.DockerSocket))
|
||||
return mux
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||
"github.com/yusing/go-proxy/agent/pkg/handler"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
@ -33,12 +33,11 @@ func StartAgentServer(parent task.Parent, opt Options) {
|
|||
tlsConfig.ClientAuth = tls.NoClientCert
|
||||
}
|
||||
|
||||
logger := logging.GetLogger()
|
||||
agentServer := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", opt.Port),
|
||||
Handler: handler.NewAgentHandler(),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
server.Start(parent, agentServer, nil, logger)
|
||||
server.Start(parent, agentServer, nil, &log.Logger)
|
||||
}
|
||||
|
|
11
cmd/main.go
11
cmd/main.go
|
@ -4,6 +4,7 @@ import (
|
|||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/auth"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
|
@ -35,8 +36,8 @@ func main() {
|
|||
initProfiling()
|
||||
|
||||
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
|
||||
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||
logging.Trace().Msg("trace enabled")
|
||||
log.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||
log.Trace().Msg("trace enabled")
|
||||
parallel(
|
||||
dnsproviders.InitProviders,
|
||||
homepage.InitIconListCache,
|
||||
|
@ -45,7 +46,7 @@ func main() {
|
|||
)
|
||||
|
||||
if common.APIJWTSecret == nil {
|
||||
logging.Warn().Msg("API_JWT_SECRET is not set, using random key")
|
||||
log.Warn().Msg("API_JWT_SECRET is not set, using random key")
|
||||
common.APIJWTSecret = common.RandomJWTKey()
|
||||
}
|
||||
|
||||
|
@ -62,7 +63,7 @@ func main() {
|
|||
Proxy: true,
|
||||
})
|
||||
if err := auth.Initialize(); err != nil {
|
||||
logging.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||
log.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||
}
|
||||
// API Handler needs to start after auth is initialized.
|
||||
cfg.StartServers(&config.StartServersOptions{
|
||||
|
@ -78,7 +79,7 @@ func main() {
|
|||
func prepareDirectory(dir string) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
||||
log.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
7
go.mod
7
go.mod
|
@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy/agent => ./agent
|
|||
|
||||
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
|
||||
|
||||
replace github.com/yusing/go-proxy/internal/utils => ./internal/utils
|
||||
|
||||
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22
|
||||
|
||||
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
||||
|
@ -45,7 +47,7 @@ require (
|
|||
github.com/stretchr/testify v1.10.0
|
||||
github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000
|
||||
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000
|
||||
go.uber.org/atomic v1.11.0
|
||||
github.com/yusing/go-proxy/internal/utils v0.0.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -219,6 +221,7 @@ require (
|
|||
go.opentelemetry.io/otel v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/mock v0.5.2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
|
@ -226,7 +229,7 @@ require (
|
|||
golang.org/x/mod v0.24.0 // indirect
|
||||
golang.org/x/sync v0.14.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
google.golang.org/api v0.234.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect
|
||||
|
|
|
@ -5,9 +5,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/maxmind"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool {
|
|||
return c.created.Add(cacheTTL).Before(utils.TimeNow())
|
||||
}
|
||||
|
||||
//TODO: add stats
|
||||
// TODO: add stats
|
||||
|
||||
const (
|
||||
ACLAllow = "allow"
|
||||
|
@ -97,7 +97,7 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
|
|||
if c.valErr != nil {
|
||||
return c.valErr
|
||||
}
|
||||
logging.Info().
|
||||
log.Info().
|
||||
Str("default", c.Default).
|
||||
Bool("allow_local", c.allowLocal).
|
||||
Int("allow_rules", len(c.Allow)).
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
func TestMatchers(t *testing.T) {
|
||||
|
@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) {
|
|||
}
|
||||
|
||||
var mathers Matchers
|
||||
err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
|
||||
err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil }
|
|||
func (noConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (noConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (cfg *Config) WrapTCP(lis net.Listener) net.Listener {
|
||||
if cfg == nil {
|
||||
func (c *Config) WrapTCP(lis net.Listener) net.Listener {
|
||||
if c == nil {
|
||||
return lis
|
||||
}
|
||||
return &TCPListener{
|
||||
acl: cfg,
|
||||
acl: c,
|
||||
lis: lis,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@ package certapi
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
|
|||
gperr.LogError("failed to obtain cert", err)
|
||||
_ = gpwebsocket.WriteText(conn, err.Error())
|
||||
} else {
|
||||
logging.Info().Msg("cert obtained successfully")
|
||||
log.Info().Msg("cert obtained successfully")
|
||||
}
|
||||
}()
|
||||
for {
|
||||
|
|
|
@ -9,12 +9,13 @@ import (
|
|||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/pkg/stdcopy"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
// FIXME: agent logs not updating.
|
||||
func Logs(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
server := r.PathValue("server")
|
||||
|
@ -68,7 +69,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
|
|||
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
|
||||
return
|
||||
}
|
||||
logging.Err(err).
|
||||
log.Err(err).
|
||||
Str("server", server).
|
||||
Str("container", containerID).
|
||||
Msg("failed to de-multiplex logs")
|
||||
|
|
|
@ -11,9 +11,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
|
@ -108,11 +108,11 @@ func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
|||
RefreshToken: token,
|
||||
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||
})
|
||||
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
|
||||
log.Debug().Str("username", username).Msg("stored oauth refresh token")
|
||||
}
|
||||
|
||||
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
||||
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
||||
log.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
||||
oauthRefreshTokens.Delete(string(sessionID))
|
||||
}
|
||||
|
||||
|
@ -127,7 +127,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
|
|||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||
signed, err := jwtToken.SignedString(common.APIJWTSecret)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to sign session token")
|
||||
log.Err(err).Msg("failed to sign session token")
|
||||
return
|
||||
}
|
||||
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
|
||||
|
@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
|
|||
return nil, refreshToken.err
|
||||
}
|
||||
|
||||
idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
|
||||
idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken)
|
||||
if err != nil {
|
||||
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
|
||||
return nil, refreshToken.err
|
||||
|
@ -205,7 +205,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
|
|||
|
||||
sessionID := newSessionID()
|
||||
|
||||
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||
log.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
||||
|
||||
refreshToken.result = &RefreshResult{
|
||||
|
|
|
@ -12,9 +12,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -38,8 +38,8 @@ type (
|
|||
|
||||
const (
|
||||
CookieOauthState = "godoxy_oidc_state"
|
||||
CookieOauthToken = "godoxy_oauth_token"
|
||||
CookieOauthSessionToken = "godoxy_session_token"
|
||||
CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
|
||||
CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -79,7 +79,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all
|
|||
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
||||
if err != nil && provider.EndSessionEndpoint() != "" {
|
||||
// non critical, just warn
|
||||
logging.Warn().
|
||||
log.Warn().
|
||||
Str("issuer", issuerURL).
|
||||
Err(err).
|
||||
Msg("failed to parse end session URL")
|
||||
|
@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
|||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||
func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", nil, errMissingIDToken
|
||||
|
@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// clear cookies then redirect to home
|
||||
logging.Err(err).Msg("failed to refresh token")
|
||||
log.Err(err).Msg("failed to refresh token")
|
||||
auth.clearCookie(w, r)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
|
@ -201,11 +201,12 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
|||
|
||||
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
||||
userAllowed := slices.Contains(auth.allowedUsers, user)
|
||||
if !userAllowed {
|
||||
return false
|
||||
if userAllowed {
|
||||
return true
|
||||
}
|
||||
if len(auth.allowedGroups) == 0 {
|
||||
return true
|
||||
// user is not allowed, but no groups are allowed
|
||||
return false
|
||||
}
|
||||
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
|
||||
}
|
||||
|
@ -257,7 +258,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
|
|||
return
|
||||
}
|
||||
|
||||
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token)
|
||||
idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
|
|
|
@ -5,13 +5,16 @@ import (
|
|||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
|
||||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -22,13 +25,19 @@ type Config struct {
|
|||
KeyPath string `json:"key_path,omitempty"`
|
||||
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
CADirURL string `json:"ca_dir_url,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
|
||||
HTTPClient *http.Client `json:"-"` // for tests only
|
||||
|
||||
challengeProvider challenge.Provider
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMissingDomain = gperr.New("missing field 'domains'")
|
||||
ErrMissingEmail = gperr.New("missing field 'email'")
|
||||
ErrMissingProvider = gperr.New("missing field 'provider'")
|
||||
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
|
||||
ErrInvalidDomain = gperr.New("invalid domain")
|
||||
ErrUnknownProvider = gperr.New("unknown provider")
|
||||
)
|
||||
|
@ -36,6 +45,7 @@ var (
|
|||
const (
|
||||
ProviderLocal = "local"
|
||||
ProviderPseudo = "pseudo"
|
||||
ProviderCustom = "custom"
|
||||
)
|
||||
|
||||
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
||||
|
@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error {
|
|||
}
|
||||
|
||||
b := gperr.NewBuilder("autocert errors")
|
||||
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
|
||||
b.Add(ErrMissingCADirURL)
|
||||
}
|
||||
|
||||
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||
if len(cfg.Domains) == 0 {
|
||||
b.Add(ErrMissingDomain)
|
||||
|
@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error {
|
|||
if cfg.Email == "" {
|
||||
b.Add(ErrMissingEmail)
|
||||
}
|
||||
for i, d := range cfg.Domains {
|
||||
if !domainOrWildcardRE.MatchString(d) {
|
||||
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||
if cfg.Provider != ProviderCustom {
|
||||
for i, d := range cfg.Domains {
|
||||
if !domainOrWildcardRE.MatchString(d) {
|
||||
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
// check if provider is implemented
|
||||
providerConstructor, ok := Providers[cfg.Provider]
|
||||
if !ok {
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
|
||||
if cfg.Provider != ProviderCustom {
|
||||
b.Add(ErrUnknownProvider.
|
||||
Subject(cfg.Provider).
|
||||
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
|
||||
}
|
||||
} else {
|
||||
_, err := providerConstructor(cfg.Options)
|
||||
provider, err := providerConstructor(cfg.Options)
|
||||
if err != nil {
|
||||
b.Add(err)
|
||||
} else {
|
||||
cfg.challengeProvider = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.challengeProvider == nil {
|
||||
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
|
||||
}
|
||||
return b.Error()
|
||||
}
|
||||
|
||||
|
@ -100,8 +124,7 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
|||
|
||||
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||
if privKey, err = cfg.LoadACMEKey(); err != nil {
|
||||
logging.Info().Err(err).Msg("load ACME private key failed")
|
||||
logging.Info().Msg("generate new ACME private key")
|
||||
log.Info().Err(err).Msg("failed to load ACME private key, generating a now one")
|
||||
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, gperr.New("generate ACME private key").With(err)
|
||||
|
@ -118,12 +141,23 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
|||
}
|
||||
|
||||
legoCfg := lego.NewConfig(user)
|
||||
legoCfg.Certificate.KeyType = certcrypto.RSA2048
|
||||
legoCfg.Certificate.KeyType = certcrypto.EC256
|
||||
|
||||
if cfg.HTTPClient != nil {
|
||||
legoCfg.HTTPClient = cfg.HTTPClient
|
||||
}
|
||||
|
||||
if cfg.CADirURL != "" {
|
||||
legoCfg.CADirURL = cfg.CADirURL
|
||||
}
|
||||
|
||||
return user, legoCfg, nil
|
||||
}
|
||||
|
||||
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||
if common.IsTest {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -132,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
|||
}
|
||||
|
||||
func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error {
|
||||
if common.IsTest {
|
||||
return nil
|
||||
}
|
||||
data, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -5,18 +5,20 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/certificate"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
@ -76,13 +78,11 @@ func (p *Provider) ObtainCert() error {
|
|||
}
|
||||
|
||||
if p.cfg.Provider == ProviderPseudo {
|
||||
t := time.NewTicker(1000 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
logging.Info().Msg("init client for pseudo provider")
|
||||
<-t.C
|
||||
logging.Info().Msg("registering acme for pseudo provider")
|
||||
<-t.C
|
||||
logging.Info().Msg("obtained cert for pseudo provider")
|
||||
log.Info().Msg("init client for pseudo provider")
|
||||
<-time.After(time.Second)
|
||||
log.Info().Msg("registering acme for pseudo provider")
|
||||
<-time.After(time.Second)
|
||||
log.Info().Msg("obtained cert for pseudo provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ func (p *Provider) ObtainCert() error {
|
|||
})
|
||||
if err != nil {
|
||||
p.legoCert = nil
|
||||
logging.Err(err).Msg("cert renew failed, fallback to obtain")
|
||||
log.Err(err).Msg("cert renew failed, fallback to obtain")
|
||||
} else {
|
||||
p.legoCert = cert
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func (p *Provider) LoadCert() error {
|
|||
p.tlsCert = &cert
|
||||
p.certExpiries = expiries
|
||||
|
||||
logging.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
log.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
return p.renewIfNeeded()
|
||||
}
|
||||
|
||||
|
@ -219,13 +219,7 @@ func (p *Provider) initClient() error {
|
|||
return err
|
||||
}
|
||||
|
||||
generator := Providers[p.cfg.Provider]
|
||||
legoProvider, pErr := generator(p.cfg.Options)
|
||||
if pErr != nil {
|
||||
return pErr
|
||||
}
|
||||
|
||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
||||
err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -240,7 +234,7 @@ func (p *Provider) registerACME() error {
|
|||
}
|
||||
if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil {
|
||||
p.user.Registration = reg
|
||||
logging.Info().Msg("reused acme registration from private key")
|
||||
log.Info().Msg("reused acme registration from private key")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -249,11 +243,14 @@ func (p *Provider) registerACME() error {
|
|||
return err
|
||||
}
|
||||
p.user.Registration = reg
|
||||
logging.Info().Interface("reg", reg).Msg("acme registered")
|
||||
log.Info().Interface("reg", reg).Msg("acme registered")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) saveCert(cert *certificate.Resource) error {
|
||||
if common.IsTest {
|
||||
return nil
|
||||
}
|
||||
/* This should have been done in setup
|
||||
but double check is always a good choice.*/
|
||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||
|
@ -283,22 +280,19 @@ func (p *Provider) certState() CertState {
|
|||
return CertStateExpired
|
||||
}
|
||||
|
||||
certDomains := make([]string, len(p.certExpiries))
|
||||
wantedDomains := make([]string, len(p.cfg.Domains))
|
||||
i := 0
|
||||
for domain := range p.certExpiries {
|
||||
certDomains[i] = domain
|
||||
i++
|
||||
}
|
||||
copy(wantedDomains, p.cfg.Domains)
|
||||
sort.Strings(wantedDomains)
|
||||
sort.Strings(certDomains)
|
||||
|
||||
if !reflect.DeepEqual(certDomains, wantedDomains) {
|
||||
logging.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
||||
if len(p.certExpiries) != len(p.cfg.Domains) {
|
||||
return CertStateMismatch
|
||||
}
|
||||
|
||||
for i := range len(p.cfg.Domains) {
|
||||
if _, ok := p.certExpiries[p.cfg.Domains[i]]; !ok {
|
||||
log.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s",
|
||||
strings.Join(slices.Collect(maps.Keys(p.certExpiries)), ", "),
|
||||
strings.Join(p.cfg.Domains, ", "))
|
||||
return CertStateMismatch
|
||||
}
|
||||
}
|
||||
|
||||
return CertStateValid
|
||||
}
|
||||
|
||||
|
@ -309,9 +303,9 @@ func (p *Provider) renewIfNeeded() error {
|
|||
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
logging.Info().Msg("certs expired, renewing")
|
||||
log.Info().Msg("certs expired, renewing")
|
||||
case CertStateMismatch:
|
||||
logging.Info().Msg("cert domains mismatch with config, renewing")
|
||||
log.Info().Msg("cert domains mismatch with config, renewing")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
|
453
internal/autocert/provider_test/custom_test.go
Normal file
453
internal/autocert/provider_test/custom_test.go
Normal file
|
@ -0,0 +1,453 @@
|
|||
//nolint:errchkjson,errcheck
|
||||
package provider_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/dnsproviders"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
dnsproviders.InitProviders()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestCustomProvider(t *testing.T) {
|
||||
t.Run("valid custom provider with step-ca", func(t *testing.T) {
|
||||
cfg := &autocert.Config{
|
||||
Email: "test@example.com",
|
||||
Domains: []string{"example.com", "*.example.com"},
|
||||
Provider: autocert.ProviderCustom,
|
||||
CADirURL: "https://ca.example.com:9000/acme/acme/directory",
|
||||
CertPath: "certs/custom.crt",
|
||||
KeyPath: "certs/custom.key",
|
||||
ACMEKeyPath: "certs/custom-acme.key",
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.NoError(t, err)
|
||||
|
||||
user, legoCfg, err := cfg.GetLegoConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.NotNil(t, legoCfg)
|
||||
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL)
|
||||
require.Equal(t, "test@example.com", user.Email)
|
||||
})
|
||||
|
||||
t.Run("custom provider missing CADirURL", func(t *testing.T) {
|
||||
cfg := &autocert.Config{
|
||||
Email: "test@example.com",
|
||||
Domains: []string{"example.com"},
|
||||
Provider: autocert.ProviderCustom,
|
||||
// CADirURL is missing
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
|
||||
})
|
||||
|
||||
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
|
||||
cfg := &autocert.Config{
|
||||
Email: "admin@internal.com",
|
||||
Domains: []string{"internal.example.com", "api.internal.example.com"},
|
||||
Provider: autocert.ProviderCustom,
|
||||
CADirURL: "https://step-ca.internal:443/acme/acme/directory",
|
||||
CertPath: "certs/internal.crt",
|
||||
KeyPath: "certs/internal.key",
|
||||
ACMEKeyPath: "certs/internal-acme.key",
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
require.NoError(t, err)
|
||||
|
||||
user, legoCfg, err := cfg.GetLegoConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.NotNil(t, legoCfg)
|
||||
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
|
||||
require.Equal(t, "admin@internal.com", user.Email)
|
||||
|
||||
provider := autocert.NewProvider(cfg, user, legoCfg)
|
||||
require.NotNil(t, provider)
|
||||
require.Equal(t, autocert.ProviderCustom, provider.GetName())
|
||||
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
|
||||
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
|
||||
})
|
||||
}
|
||||
|
||||
func TestObtainCertFromCustomProvider(t *testing.T) {
|
||||
// Create a test ACME server
|
||||
acmeServer := newTestACMEServer(t)
|
||||
defer acmeServer.Close()
|
||||
|
||||
t.Run("obtain cert from custom step-ca server", func(t *testing.T) {
|
||||
cfg := &autocert.Config{
|
||||
Email: "test@example.com",
|
||||
Domains: []string{"test.example.com"},
|
||||
Provider: autocert.ProviderCustom,
|
||||
CADirURL: acmeServer.URL() + "/acme/acme/directory",
|
||||
CertPath: "certs/stepca-test.crt",
|
||||
KeyPath: "certs/stepca-test.key",
|
||||
ACMEKeyPath: "certs/stepca-test-acme.key",
|
||||
HTTPClient: acmeServer.httpClient(),
|
||||
}
|
||||
|
||||
err := error(cfg.Validate())
|
||||
require.NoError(t, err)
|
||||
|
||||
user, legoCfg, err := cfg.GetLegoConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.NotNil(t, legoCfg)
|
||||
|
||||
provider := autocert.NewProvider(cfg, user, legoCfg)
|
||||
require.NotNil(t, provider)
|
||||
|
||||
// Test obtaining certificate
|
||||
err = provider.ObtainCert()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify certificate was obtained
|
||||
cert, err := provider.GetCert(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
|
||||
// Verify certificate properties
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, x509Cert.DNSNames, "test.example.com")
|
||||
require.True(t, time.Now().Before(x509Cert.NotAfter))
|
||||
require.True(t, time.Now().After(x509Cert.NotBefore))
|
||||
})
|
||||
}
|
||||
|
||||
// testACMEServer implements a minimal ACME server for testing.
|
||||
type testACMEServer struct {
|
||||
server *httptest.Server
|
||||
caCert *x509.Certificate
|
||||
caKey *rsa.PrivateKey
|
||||
clientCSRs map[string]*x509.CertificateRequest
|
||||
orderID string
|
||||
}
|
||||
|
||||
func newTestACMEServer(t *testing.T) *testACMEServer {
|
||||
t.Helper()
|
||||
|
||||
// Generate CA certificate and key
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test CA"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{"Test"},
|
||||
StreetAddress: []string{""},
|
||||
PostalCode: []string{""},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
caCert, err := x509.ParseCertificate(caCertDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
acme := &testACMEServer{
|
||||
caCert: caCert,
|
||||
caKey: caKey,
|
||||
clientCSRs: make(map[string]*x509.CertificateRequest),
|
||||
orderID: "test-order-123",
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
acme.setupRoutes(mux)
|
||||
|
||||
acme.server = httptest.NewTLSServer(mux)
|
||||
return acme
|
||||
}
|
||||
|
||||
func (s *testACMEServer) Close() {
|
||||
s.server.Close()
|
||||
}
|
||||
|
||||
func (s *testACMEServer) URL() string {
|
||||
return s.server.URL
|
||||
}
|
||||
|
||||
func (s *testACMEServer) httpClient() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
|
||||
// ACME directory endpoint
|
||||
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
|
||||
|
||||
// ACME endpoints
|
||||
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
|
||||
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
|
||||
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
|
||||
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
|
||||
mux.HandleFunc("/acme/chall/", s.handleChallenge)
|
||||
mux.HandleFunc("/acme/order/", s.handleOrder)
|
||||
mux.HandleFunc("/acme/cert/", s.handleCertificate)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
directory := map[string]interface{}{
|
||||
"newNonce": s.server.URL + "/acme/new-nonce",
|
||||
"newAccount": s.server.URL + "/acme/new-account",
|
||||
"newOrder": s.server.URL + "/acme/new-order",
|
||||
"keyChange": s.server.URL + "/acme/key-change",
|
||||
"meta": map[string]interface{}{
|
||||
"termsOfService": s.server.URL + "/terms",
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(directory)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-12345")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
account := map[string]interface{}{
|
||||
"status": "valid",
|
||||
"contact": []string{"mailto:test@example.com"},
|
||||
"orders": s.server.URL + "/acme/orders",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Location", s.server.URL+"/acme/account/1")
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-67890")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(account)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
authzID := "test-authz-456"
|
||||
|
||||
order := map[string]interface{}{
|
||||
"status": "ready", // Skip pending state for simplicity
|
||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
|
||||
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-order")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(order)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
authz := map[string]interface{}{
|
||||
"status": "valid", // Skip challenge validation for simplicity
|
||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
|
||||
"challenges": []map[string]interface{}{
|
||||
{
|
||||
"type": "dns-01",
|
||||
"status": "valid",
|
||||
"url": s.server.URL + "/acme/chall/test-chall-789",
|
||||
"token": "test-token-abc123",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-authz")
|
||||
json.NewEncoder(w).Encode(authz)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
challenge := map[string]interface{}{
|
||||
"type": "dns-01",
|
||||
"status": "valid",
|
||||
"url": r.URL.String(),
|
||||
"token": "test-token-abc123",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-chall")
|
||||
json.NewEncoder(w).Encode(challenge)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/finalize") {
|
||||
s.handleFinalize(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
||||
order := map[string]interface{}{
|
||||
"status": "valid",
|
||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||
"certificate": certURL,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
|
||||
json.NewEncoder(w).Encode(order)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
|
||||
// Read the JWS payload
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract CSR from JWS payload
|
||||
csr, err := s.extractCSRFromJWS(body)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store the CSR for certificate generation
|
||||
s.clientCSRs[s.orderID] = csr
|
||||
|
||||
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
||||
order := map[string]interface{}{
|
||||
"status": "valid",
|
||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||
"certificate": certURL,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
|
||||
json.NewEncoder(w).Encode(order)
|
||||
}
|
||||
|
||||
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
|
||||
// Parse the JWS structure
|
||||
var jws struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jwsData, &jws); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode the payload
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the finalize request
|
||||
var finalizeReq struct {
|
||||
CSR string `json:"csr"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode the CSR
|
||||
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the CSR
|
||||
csr, err := x509.ParseCertificateRequest(csrBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return csr, nil
|
||||
}
|
||||
|
||||
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract order ID from URL
|
||||
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
|
||||
|
||||
// Get the CSR for this order
|
||||
csr, exists := s.clientCSRs[orderID]
|
||||
if !exists {
|
||||
http.Error(w, "No CSR found for order", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create certificate using the public key from the client's CSR
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Cert"},
|
||||
Country: []string{"US"},
|
||||
},
|
||||
DNSNames: csr.DNSNames,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(90 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
// Use the public key from the CSR and sign with CA key
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Return certificate chain
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
|
||||
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||
w.Header().Set("Replay-Nonce", "test-nonce-cert")
|
||||
w.Write(append(certPEM, caPEM...))
|
||||
}
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/go-acme/lego/v4/providers/dns/ovh"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
// type Config struct {
|
||||
|
@ -45,6 +45,6 @@ oauth2_config:
|
|||
testYaml = testYaml[1:] // remove first \n
|
||||
opt := make(map[string]any)
|
||||
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
|
||||
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
|
||||
require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg))
|
||||
require.Equal(t, cfgExpected, cfg)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package autocert
|
|||
import (
|
||||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type Generator func(map[string]any) (challenge.Provider, gperr.Error)
|
||||
|
@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider](
|
|||
) Generator {
|
||||
return func(opt map[string]any) (challenge.Provider, gperr.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := utils.MapUnmarshalValidate(opt, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(opt) > 0 {
|
||||
err := serialization.MapUnmarshalValidate(opt, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
p, pErr := newProvider(cfg)
|
||||
return p, gperr.Wrap(pErr)
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -13,14 +13,14 @@ func (p *Provider) Setup() (err error) {
|
|||
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
|
||||
return err
|
||||
}
|
||||
logging.Debug().Msg("obtaining cert due to error loading cert")
|
||||
log.Debug().Msg("obtaining cert due to error loading cert")
|
||||
if err = p.ObtainCert(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, expiry := range p.GetExpiries() {
|
||||
logging.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
|
||||
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
|
||||
break
|
||||
}
|
||||
|
||||
|
|
|
@ -10,20 +10,20 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/api"
|
||||
autocert "github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/entrypoint"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/maxmind"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/proxmox"
|
||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
|
@ -96,10 +96,10 @@ func OnConfigChange(ev []events.Event) {
|
|||
// just reload once and check the last event
|
||||
switch ev[len(ev)-1].Action {
|
||||
case events.ActionFileRenamed:
|
||||
logging.Warn().Msg(cfgRenameWarn)
|
||||
log.Warn().Msg(cfgRenameWarn)
|
||||
return
|
||||
case events.ActionFileDeleted:
|
||||
logging.Warn().Msg(cfgDeleteWarn)
|
||||
log.Warn().Msg(cfgDeleteWarn)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -161,7 +161,7 @@ func (cfg *Config) Start(opts ...*StartServersOptions) {
|
|||
func (cfg *Config) StartAutoCert() {
|
||||
autocert := cfg.autocertProvider
|
||||
if autocert == nil {
|
||||
logging.Info().Msg("autocert not configured")
|
||||
log.Info().Msg("autocert not configured")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -223,7 +223,7 @@ func (cfg *Config) load() gperr.Error {
|
|||
}
|
||||
|
||||
model := config.DefaultConfig()
|
||||
if err := utils.UnmarshalValidateYAML(data, model); err != nil {
|
||||
if err := serialization.UnmarshalValidateYAML(data, model); err != nil {
|
||||
gperr.LogFatal(errMsg, err)
|
||||
}
|
||||
|
||||
|
@ -374,6 +374,6 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
|
|||
}
|
||||
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
|
||||
})
|
||||
logging.Info().Msg(results.String())
|
||||
log.Info().Msg(results.String())
|
||||
return errs.Error()
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/proxmox"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -93,14 +93,14 @@ func HasInstance() bool {
|
|||
|
||||
func Validate(data []byte) gperr.Error {
|
||||
var model Config
|
||||
return utils.UnmarshalValidateYAML(data, &model)
|
||||
return serialization.UnmarshalValidateYAML(data, &model)
|
||||
}
|
||||
|
||||
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
||||
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
||||
serialization.RegisterDefaultValueFactory(DefaultConfig)
|
||||
serialization.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
||||
domains := fl.Field().Interface().([]string)
|
||||
for _, domain := range domains {
|
||||
if !matchDomainsRegex.MatchString(domain) {
|
||||
|
@ -109,7 +109,7 @@ func init() {
|
|||
}
|
||||
return true
|
||||
})
|
||||
utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
|
||||
serialization.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
|
||||
m := fl.Field().Interface().(map[string]string)
|
||||
for k := range m {
|
||||
if k == "" {
|
||||
|
|
|
@ -4,6 +4,8 @@ go 1.24.3
|
|||
|
||||
replace github.com/yusing/go-proxy => ../..
|
||||
|
||||
replace github.com/yusing/go-proxy/internal/utils => ../utils
|
||||
|
||||
require (
|
||||
github.com/go-acme/lego/v4 v4.23.1
|
||||
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
||||
|
@ -156,6 +158,7 @@ require (
|
|||
github.com/vultr/govultr/v3 v3.20.0 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusing/go-proxy/internal/utils v0.0.0 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.3 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
|
||||
|
|
|
@ -13,13 +13,14 @@ import (
|
|||
|
||||
"github.com/docker/cli/cli/connhelper"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
// TODO: implement reconnect here.
|
||||
type (
|
||||
SharedClient struct {
|
||||
*client.Client
|
||||
|
@ -83,7 +84,7 @@ func closeTimedOutClients() {
|
|||
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
|
||||
delete(clientMap, c.Key())
|
||||
c.Client.Close()
|
||||
logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
|
||||
log.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -148,7 +149,7 @@ func NewClient(host string) (*SharedClient, error) {
|
|||
default:
|
||||
helper, err := connhelper.GetConnectionHelper(host)
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("failed to get connection helper")
|
||||
log.Panic().Err(err).Msg("failed to get connection helper")
|
||||
}
|
||||
if helper != nil {
|
||||
httpClient := &http.Client{
|
||||
|
@ -189,10 +190,10 @@ func NewClient(host string) (*SharedClient, error) {
|
|||
c.dial = client.Dialer()
|
||||
}
|
||||
if c.addr == "" {
|
||||
c.addr = c.Client.DaemonHost()
|
||||
c.addr = c.DaemonHost()
|
||||
}
|
||||
|
||||
defer logging.Debug().Str("host", host).Msg("docker client initialized")
|
||||
defer log.Debug().Str("host", host).Msg("docker client initialized")
|
||||
|
||||
clientMap[c.Key()] = c
|
||||
return c, nil
|
||||
|
|
|
@ -9,11 +9,12 @@ import (
|
|||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/go-connections/nat"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -90,7 +91,7 @@ func FromDocker(c *container.SummaryTrimmed, dockerHost string) (res *Container)
|
|||
var ok bool
|
||||
res.Agent, ok = config.GetInstance().GetAgent(dockerHost)
|
||||
if !ok {
|
||||
logging.Error().Msgf("agent %q not found", dockerHost)
|
||||
log.Error().Msgf("agent %q not found", dockerHost)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -183,7 +184,7 @@ func (c *Container) setPublicHostname() {
|
|||
}
|
||||
url, err := url.Parse(c.DockerHost)
|
||||
if err != nil {
|
||||
logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
||||
log.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
||||
c.PublicHostname = "127.0.0.1"
|
||||
return
|
||||
}
|
||||
|
@ -224,7 +225,7 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
|
|||
ContainerName: c.ContainerName,
|
||||
},
|
||||
}
|
||||
err := utils.MapUnmarshalValidate(cfg, idwCfg)
|
||||
err := serialization.MapUnmarshalValidate(cfg, idwCfg)
|
||||
if err != nil {
|
||||
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
|
||||
} else {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||
|
@ -50,7 +50,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
|||
}
|
||||
ep.middleware = mid
|
||||
|
||||
logging.Debug().Msg("entrypoint middleware loaded")
|
||||
log.Debug().Msg("entrypoint middleware loaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
logging.Debug().Msg("entrypoint access logger created")
|
||||
log.Debug().Msg("entrypoint access logger created")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -89,7 +89,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
// Then scraper / scanners will know the subdomain is invalid.
|
||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||
logging.Err(err).
|
||||
log.Err(err).
|
||||
Str("method", r.Method).
|
||||
Str("url", r.URL.String()).
|
||||
Str("remote", r.RemoteAddr).
|
||||
|
@ -99,7 +99,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if _, err := w.Write(errorPage); err != nil {
|
||||
logging.Err(err).Msg("failed to write error page")
|
||||
log.Err(err).Msg("failed to write error page")
|
||||
}
|
||||
} else {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
zerologlog "github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
|
||||
|
@ -13,7 +13,7 @@ func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger)
|
|||
if len(logger) > 0 {
|
||||
l = logger[0]
|
||||
} else {
|
||||
l = logging.GetLogger()
|
||||
l = &zerologlog.Logger
|
||||
}
|
||||
l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error())
|
||||
switch level {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"slices"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/homepage/widgets"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -32,7 +32,7 @@ type (
|
|||
)
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(func() *ItemConfig {
|
||||
serialization.RegisterDefaultValueFactory(func() *ItemConfig {
|
||||
return &ItemConfig{
|
||||
Show: true,
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||
|
@ -74,7 +74,7 @@ func pruneExpiredIconCache() {
|
|||
}
|
||||
}
|
||||
if nPruned > 0 {
|
||||
logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
|
||||
log.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,7 +87,7 @@ func loadIconCache(key string) *FetchResult {
|
|||
defer iconMu.RUnlock()
|
||||
icon, ok := iconCache.Load(key)
|
||||
if ok && len(icon.Icon) > 0 {
|
||||
logging.Debug().
|
||||
log.Debug().
|
||||
Str("key", key).
|
||||
Msg("icon found in cache")
|
||||
icon.LastAccess.Store(utils.TimeNow())
|
||||
|
@ -99,7 +99,7 @@ func loadIconCache(key string) *FetchResult {
|
|||
func storeIconCache(key string, result *FetchResult) {
|
||||
icon := result.Icon
|
||||
if len(icon) > maxIconSize {
|
||||
logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
|
||||
log.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -109,7 +109,7 @@ func storeIconCache(key string, result *FetchResult) {
|
|||
entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
|
||||
entry.LastAccess.Store(time.Now())
|
||||
iconCache.Store(key, entry)
|
||||
logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
|
||||
log.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
|
||||
}
|
||||
|
||||
func (e *cacheEntry) IsExpired() bool {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package homepage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -10,10 +11,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -46,30 +47,30 @@ type (
|
|||
func (icon *IconMeta) Filenames(ref string) []string {
|
||||
filenames := make([]string, 0)
|
||||
if icon.SVG {
|
||||
filenames = append(filenames, fmt.Sprintf("%s.svg", ref))
|
||||
filenames = append(filenames, ref+".svg")
|
||||
if icon.Light {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref))
|
||||
filenames = append(filenames, ref+"-light.svg")
|
||||
}
|
||||
if icon.Dark {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref))
|
||||
filenames = append(filenames, ref+"-dark.svg")
|
||||
}
|
||||
}
|
||||
if icon.PNG {
|
||||
filenames = append(filenames, fmt.Sprintf("%s.png", ref))
|
||||
filenames = append(filenames, ref+".png")
|
||||
if icon.Light {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-light.png", ref))
|
||||
filenames = append(filenames, ref+"-light.png")
|
||||
}
|
||||
if icon.Dark {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref))
|
||||
filenames = append(filenames, ref+"-dark.png")
|
||||
}
|
||||
}
|
||||
if icon.WebP {
|
||||
filenames = append(filenames, fmt.Sprintf("%s.webp", ref))
|
||||
filenames = append(filenames, ref+".webp")
|
||||
if icon.Light {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref))
|
||||
filenames = append(filenames, ref+"-light.webp")
|
||||
}
|
||||
if icon.Dark {
|
||||
filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref))
|
||||
filenames = append(filenames, ref+"-dark.webp")
|
||||
}
|
||||
}
|
||||
return filenames
|
||||
|
@ -99,21 +100,21 @@ func InitIconListCache() {
|
|||
iconsCache.Lock()
|
||||
defer iconsCache.Unlock()
|
||||
|
||||
err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache)
|
||||
err := serialization.LoadJSONIfExist(common.IconListCachePath, iconsCache)
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Msg("failed to load icons")
|
||||
log.Error().Err(err).Msg("failed to load icons")
|
||||
} else if len(iconsCache.Icons) > 0 {
|
||||
logging.Info().
|
||||
log.Info().
|
||||
Int("icons", len(iconsCache.Icons)).
|
||||
Msg("icons loaded")
|
||||
}
|
||||
|
||||
if err = updateIcons(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to update icons")
|
||||
log.Error().Err(err).Msg("failed to update icons")
|
||||
}
|
||||
|
||||
task.OnProgramExit("save_icons_cache", func() {
|
||||
utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||
_ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -134,17 +135,17 @@ func ListAvailableIcons() (*Cache, error) {
|
|||
iconsCache.Lock()
|
||||
defer iconsCache.Unlock()
|
||||
|
||||
logging.Info().Msg("updating icon data")
|
||||
log.Info().Msg("updating icon data")
|
||||
if err := updateIcons(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logging.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
|
||||
log.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
|
||||
|
||||
iconsCache.LastUpdate = time.Now()
|
||||
|
||||
err := utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||
err := serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msg("failed to save icons")
|
||||
log.Warn().Err(err).Msg("failed to save icons")
|
||||
}
|
||||
return iconsCache, nil
|
||||
}
|
||||
|
@ -230,14 +231,17 @@ func updateIcons() error {
|
|||
|
||||
var httpGet = httpGetImpl
|
||||
|
||||
func MockHttpGet(body []byte) {
|
||||
func MockHTTPGet(body []byte) {
|
||||
httpGet = func(_ string) ([]byte, error) {
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
func httpGetImpl(url string) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error {
|
|||
}
|
||||
|
||||
data := make([]SelfhStIcon, 0)
|
||||
err = json.Unmarshal(body, &data)
|
||||
err = json.Unmarshal(body, &data) //nolint:musttag
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -68,6 +68,8 @@ type testCases struct {
|
|||
}
|
||||
|
||||
func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
|
||||
t.Helper()
|
||||
|
||||
for _, item := range test {
|
||||
icon, ok := iconsCache.Icons[item.Key]
|
||||
if !ok {
|
||||
|
@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
|
|||
}
|
||||
|
||||
func TestListWalkxCodeIcons(t *testing.T) {
|
||||
MockHttpGet([]byte(walkxcodeIcons))
|
||||
MockHTTPGet([]byte(walkxcodeIcons))
|
||||
if err := UpdateWalkxCodeIcons(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestListSelfhstIcons(t *testing.T) {
|
||||
MockHttpGet([]byte(selfhstIcons))
|
||||
MockHTTPGet([]byte(selfhstIcons))
|
||||
if err := UpdateSelfhstIcons(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{
|
|||
var ErrInvalidProvider = gperr.New("invalid provider")
|
||||
|
||||
func (cfg *Config) UnmarshalMap(m map[string]any) error {
|
||||
cfg.Provider = m["provider"].(string)
|
||||
var ok bool
|
||||
cfg.Provider, ok = m["provider"].(string)
|
||||
if !ok {
|
||||
return ErrInvalidProvider.Withf("non string")
|
||||
}
|
||||
if _, ok := widgetProviders[cfg.Provider]; !ok {
|
||||
return ErrInvalidProvider.Subject(cfg.Provider)
|
||||
}
|
||||
delete(m, "provider")
|
||||
m, ok := m["config"].(map[string]any)
|
||||
m, ok = m["config"].(map[string]any)
|
||||
if !ok {
|
||||
return gperr.New("invalid config")
|
||||
}
|
||||
if err := utils.MapUnmarshalValidate(m, &cfg.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return serialization.MapUnmarshalValidate(m, &cfg.Config)
|
||||
}
|
||||
|
|
|
@ -7,10 +7,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
|
@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{
|
|||
}
|
||||
|
||||
var (
|
||||
causeReload = gperr.New("reloaded")
|
||||
causeContainerDestroy = gperr.New("container destroyed")
|
||||
causeReload = gperr.New("reloaded") //nolint:errname
|
||||
causeContainerDestroy = gperr.New("container destroyed") //nolint:errname
|
||||
)
|
||||
|
||||
const reqTimeout = 3 * time.Second
|
||||
|
||||
// TODO: fix stream type
|
||||
// TODO: fix stream type.
|
||||
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
||||
cfg := r.IdlewatcherConfig()
|
||||
key := cfg.Key()
|
||||
|
@ -120,7 +120,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
|||
return nil, err
|
||||
}
|
||||
w.provider = p
|
||||
w.l = logging.With().
|
||||
w.l = log.With().
|
||||
Str("provider", providerType).
|
||||
Str("container", cfg.ContainerName()).
|
||||
Logger()
|
||||
|
|
|
@ -2,18 +2,17 @@ package jsonstore
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
|
||||
"maps"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type namespace string
|
||||
|
@ -36,13 +35,15 @@ type store interface {
|
|||
json.Unmarshaler
|
||||
}
|
||||
|
||||
var stores = make(map[namespace]store)
|
||||
var storesPath = common.DataDir
|
||||
var (
|
||||
stores = make(map[namespace]store)
|
||||
storesPath = common.DataDir
|
||||
)
|
||||
|
||||
func init() {
|
||||
task.OnProgramExit("save_stores", func() {
|
||||
if err := save(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to save stores")
|
||||
log.Error().Err(err).Msg("failed to save stores")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -54,20 +55,20 @@ func loadNS[T store](ns namespace) T {
|
|||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logging.Err(err).
|
||||
log.Err(err).
|
||||
Str("path", path).
|
||||
Msg("failed to load store")
|
||||
}
|
||||
} else {
|
||||
defer file.Close()
|
||||
if err := json.NewDecoder(file).Decode(&store); err != nil {
|
||||
logging.Err(err).
|
||||
log.Err(err).
|
||||
Str("path", path).
|
||||
Msg("failed to load store")
|
||||
}
|
||||
}
|
||||
stores[ns] = store
|
||||
logging.Debug().
|
||||
log.Debug().
|
||||
Str("namespace", string(ns)).
|
||||
Str("path", path).
|
||||
Msg("loaded store")
|
||||
|
@ -77,7 +78,7 @@ func loadNS[T store](ns namespace) T {
|
|||
func save() error {
|
||||
errs := gperr.NewBuilder("failed to save data stores")
|
||||
for ns, store := range stores {
|
||||
if err := utils.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
|
||||
if err := serialization.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +87,7 @@ func save() error {
|
|||
|
||||
func Store[VT any](namespace namespace) MapStore[VT] {
|
||||
if _, ok := stores[namespace]; ok {
|
||||
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||
}
|
||||
store := loadNS[*MapStore[VT]](namespace)
|
||||
stores[namespace] = store
|
||||
|
@ -95,7 +96,7 @@ func Store[VT any](namespace namespace) MapStore[VT] {
|
|||
|
||||
func Object[Ptr Initializer](namespace namespace) Ptr {
|
||||
if _, ok := stores[namespace]; ok {
|
||||
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||
}
|
||||
obj := loadNS[*ObjectStore[Ptr]](namespace)
|
||||
stores[namespace] = obj
|
||||
|
@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error {
|
|||
}
|
||||
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
|
||||
for k, v := range tmp {
|
||||
s.Map.Store(k, v)
|
||||
s.Store(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
|
@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if io == nil {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return NewAccessLoggerWithIO(parent, io, cfg), nil
|
||||
}
|
||||
|
||||
|
@ -120,7 +123,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
|
|||
bufSize: MinBufferSize,
|
||||
lineBufPool: synk.NewBytesPool(),
|
||||
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
|
||||
logger: logging.With().Str("file", writer.Name()).Logger(),
|
||||
logger: log.With().Str("file", writer.Name()).Logger(),
|
||||
}
|
||||
|
||||
l.supportRotate = unwrap[supportRotate](writer)
|
||||
|
@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
|
|||
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
|
||||
line := l.lineBufPool.Get()
|
||||
defer l.lineBufPool.Put(line)
|
||||
line = l.ACLFormatter.AppendACLLog(line, info, blocked)
|
||||
line = l.AppendACLLog(line, info, blocked)
|
||||
if line[len(line)-1] != '\n' {
|
||||
line = append(line, '\n')
|
||||
}
|
||||
|
@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool {
|
|||
|
||||
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
||||
if !l.ShouldRotate() {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
l.writer.Flush()
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -126,6 +126,6 @@ func DefaultACLLoggerConfig() *ACLLoggerConfig {
|
|||
}
|
||||
|
||||
func init() {
|
||||
utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
|
||||
utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
|
||||
serialization.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
|
||||
serialization.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
. "github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
|
|||
expect.NoError(t, err)
|
||||
|
||||
var config RequestLoggerConfig
|
||||
err = utils.MapUnmarshalValidate(parsed, &config)
|
||||
err = serialization.MapUnmarshalValidate(parsed, &config)
|
||||
expect.NoError(t, err)
|
||||
|
||||
expect.Equal(t, config.Format, FormatCombined)
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) {
|
|||
if opened, ok := openedFiles[path]; ok {
|
||||
opened.refCount.Add()
|
||||
return opened, nil
|
||||
} else {
|
||||
// cannot open as O_APPEND as we need Seek and WriteAt
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access log open error: %w", err)
|
||||
}
|
||||
if _, err := f.Seek(0, io.SeekEnd); err != nil {
|
||||
return nil, fmt.Errorf("access log seek error: %w", err)
|
||||
}
|
||||
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
|
||||
openedFiles[path] = file
|
||||
go file.closeOnZero()
|
||||
}
|
||||
|
||||
// cannot open as O_APPEND as we need Seek and WriteAt
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("access log open error: %w", err)
|
||||
}
|
||||
if _, err := f.Seek(0, io.SeekEnd); err != nil {
|
||||
return nil, fmt.Errorf("access log seek error: %w", err)
|
||||
}
|
||||
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
|
||||
openedFiles[path] = file
|
||||
go file.closeOnZero()
|
||||
return file, nil
|
||||
}
|
||||
|
||||
|
@ -90,7 +89,7 @@ func (f *File) Close() error {
|
|||
}
|
||||
|
||||
func (f *File) closeOnZero() {
|
||||
defer logging.Debug().
|
||||
defer log.Debug().
|
||||
Str("path", f.path).
|
||||
Msg("access log closed")
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
//nolint:zerologlint
|
||||
package logging
|
||||
|
||||
import (
|
||||
|
@ -10,6 +9,8 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
||||
zerologlog "github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -61,22 +62,6 @@ func InitLogger(out ...io.Writer) {
|
|||
log.SetOutput(writer)
|
||||
log.SetPrefix("")
|
||||
log.SetFlags(0)
|
||||
zerolog.TimeFieldFormat = timeFmt
|
||||
zerologlog.Logger = logger
|
||||
}
|
||||
|
||||
func DiscardLogger() { zerolog.SetGlobalLevel(zerolog.Disabled) }
|
||||
|
||||
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
|
||||
|
||||
func GetLogger() *zerolog.Logger { return &logger }
|
||||
func With() zerolog.Context { return logger.With() }
|
||||
|
||||
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
|
||||
|
||||
func Info() *zerolog.Event { return logger.Info() }
|
||||
func Warn() *zerolog.Event { return logger.Warn() }
|
||||
func Error() *zerolog.Event { return logger.Error() }
|
||||
func Err(err error) *zerolog.Event { return logger.Err(err) }
|
||||
func Debug() *zerolog.Event { return logger.Debug() }
|
||||
func Fatal() *zerolog.Event { return logger.Fatal() }
|
||||
func Panic() *zerolog.Event { return logger.Panic() }
|
||||
func Trace() *zerolog.Event { return logger.Trace() }
|
||||
|
|
|
@ -2,8 +2,8 @@ package maxmind
|
|||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -28,6 +28,6 @@ func (cfg *Config) Validate() gperr.Error {
|
|||
}
|
||||
|
||||
func (cfg *Config) Logger() *zerolog.Logger {
|
||||
l := logging.With().Str("database", string(cfg.Database)).Logger()
|
||||
l := log.With().Str("database", string(cfg.Database)).Logger()
|
||||
return &l
|
||||
}
|
||||
|
|
|
@ -10,8 +10,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||
)
|
||||
|
@ -47,7 +47,7 @@ var initDataDirOnce sync.Once
|
|||
|
||||
func initDataDir() {
|
||||
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to create metrics data directory")
|
||||
log.Error().Err(err).Msg("failed to create metrics data directory")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
|
|||
}
|
||||
|
||||
func (p *Poller[T, AggregateT]) savePath() string {
|
||||
return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name))
|
||||
return filepath.Join(saveBaseDir, p.name+".json")
|
||||
}
|
||||
|
||||
func (p *Poller[T, AggregateT]) load() error {
|
||||
|
@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
|
|||
|
||||
func (p *Poller[T, AggregateT]) Start() {
|
||||
t := task.RootTask("poller." + p.name)
|
||||
l := log.With().Str("name", p.name).Logger()
|
||||
err := p.load()
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logging.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name)
|
||||
l.Err(err).Msg("failed to load last metrics data")
|
||||
}
|
||||
} else {
|
||||
logging.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total())
|
||||
l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data")
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() {
|
|||
gatherErrsTicker.Stop()
|
||||
saveTicker.Stop()
|
||||
|
||||
p.save()
|
||||
if err := p.save(); err != nil {
|
||||
l.Err(err).Msg("failed to save metrics data")
|
||||
}
|
||||
t.Finish(nil)
|
||||
}()
|
||||
|
||||
logging.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval)
|
||||
l.Debug().Dur("interval", pollInterval).Msg("Starting poller")
|
||||
|
||||
p.pollWithTimeout(t.Context())
|
||||
|
||||
|
@ -176,7 +179,7 @@ func (p *Poller[T, AggregateT]) Start() {
|
|||
case <-gatherErrsTicker.C:
|
||||
errs, ok := p.gatherErrs()
|
||||
if ok {
|
||||
logging.Error().Msg(errs)
|
||||
log.Error().Msg(errs)
|
||||
}
|
||||
p.clearErrs()
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/disk"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
|
@ -16,7 +17,6 @@ import (
|
|||
"github.com/shirou/gopsutil/v4/warning"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/metrics/period"
|
||||
)
|
||||
|
||||
|
@ -130,7 +130,7 @@ func getSystemInfo(ctx context.Context, lastResult *SystemInfo) (*SystemInfo, er
|
|||
}
|
||||
})
|
||||
if allWarnings.HasError() {
|
||||
logging.Warn().Msg(allWarnings.String())
|
||||
log.Warn().Msg(allWarnings.String())
|
||||
}
|
||||
if allErrors.HasError() {
|
||||
return nil, allErrors.Error()
|
||||
|
@ -195,7 +195,7 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf
|
|||
if len(s.Disks) == 0 {
|
||||
return errs.Error()
|
||||
}
|
||||
logging.Warn().Msg(errs.String())
|
||||
log.Warn().Msg(errs.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||
|
@ -14,9 +14,9 @@ func WriteBody(w http.ResponseWriter, body []byte) {
|
|||
switch {
|
||||
case errors.Is(err, http.ErrHandlerTimeout),
|
||||
errors.Is(err, context.DeadlineExceeded):
|
||||
logging.Err(err).Msg("timeout writing body")
|
||||
log.Err(err).Msg("timeout writing body")
|
||||
default:
|
||||
logging.Err(err).Msg("failed to write body")
|
||||
log.Err(err).Msg("failed to write body")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,11 +9,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func warnNoMatchDomains() {
|
||||
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
log.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||
}
|
||||
|
||||
var warnNoMatchDomainOnce sync.Once
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -47,7 +47,7 @@ func New(cfg *Config) *LoadBalancer {
|
|||
lb := &LoadBalancer{
|
||||
Config: new(Config),
|
||||
pool: pool.New[Server]("loadbalancer." + cfg.Link),
|
||||
l: logging.With().Str("name", cfg.Link).Logger(),
|
||||
l: log.With().Str("name", cfg.Link).Logger(),
|
||||
}
|
||||
lb.UpdateConfigIfNeeded(cfg)
|
||||
return lb
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
package types
|
||||
|
||||
type Weight uint16
|
||||
type Weight int
|
||||
|
|
|
@ -4,14 +4,14 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||
return logging.WithLevel(level).
|
||||
Str("remote", r.RemoteAddr).
|
||||
Str("host", r.Host).
|
||||
Str("uri", r.Method+" "+r.RequestURI)
|
||||
return log.WithLevel(level). //nolint:zerologlint
|
||||
Str("remote", r.RemoteAddr).
|
||||
Str("host", r.Host).
|
||||
Str("uri", r.Method+" "+r.RequestURI)
|
||||
}
|
||||
|
||||
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
||||
|
|
|
@ -4,8 +4,8 @@ import (
|
|||
"net/http"
|
||||
"text/template"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/auth"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
|
||||
_ "embed"
|
||||
|
@ -55,7 +55,7 @@ func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed boo
|
|||
"FormHTML": p.FormHTML(),
|
||||
})
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Msg("failed to execute captcha page")
|
||||
log.Error().Err(err).Msg("failed to execute captcha page")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/go-playground/validator/v10"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
|
@ -34,7 +34,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
||||
serialization.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
||||
statusCode := fl.Field().Int()
|
||||
return gphttp.IsStatusCodeValid(int(statusCode))
|
||||
})
|
||||
|
@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
|||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
||||
for _, cidr := range wl.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
|
@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
|||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
|||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
ExpectError(t, serialization.ErrValidationError, err)
|
||||
})
|
||||
t.Run("invalid cidr", func(t *testing.T) {
|
||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||
|
@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
|||
"status_code": 600,
|
||||
"message": testMessage,
|
||||
})
|
||||
ExpectError(t, utils.ErrValidationError, err)
|
||||
ExpectError(t, serialization.ErrValidationError, err)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -9,8 +10,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
@ -89,21 +90,29 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
|||
)
|
||||
if err != nil {
|
||||
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
|
||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
log.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||
return nil
|
||||
}
|
||||
if len(cfCIDRs) == 0 {
|
||||
logging.Warn().Msg("cloudflare CIDR range is empty")
|
||||
log.Warn().Msg("cloudflare CIDR range is empty")
|
||||
}
|
||||
}
|
||||
|
||||
cfCIDRsLastUpdate.Store(time.Now())
|
||||
logging.Info().Msg("cloudflare CIDR range updated")
|
||||
log.Info().Msg("cloudflare CIDR range updated")
|
||||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||
|
@ -32,7 +32,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
|||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||
if ok {
|
||||
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
log.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||
|
@ -40,7 +40,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
|||
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||
} else {
|
||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
log.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
|||
filename := path[len(StaticFilePathPrefix):]
|
||||
file, ok := errorpage.GetStaticFile(filename)
|
||||
if !ok {
|
||||
logging.Error().Msg("unable to load resource " + filename)
|
||||
log.Error().Msg("unable to load resource " + filename)
|
||||
return false
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
|
@ -68,10 +68,10 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
|||
case ".css":
|
||||
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||
default:
|
||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
log.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||
}
|
||||
if _, err := w.Write(file); err != nil {
|
||||
logging.Err(err).Msg("unable to write resource " + filename)
|
||||
log.Err(err).Msg("unable to write resource " + filename)
|
||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||
}
|
||||
return true
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
|
@ -48,7 +48,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
|||
func loadContent() {
|
||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list error page resources")
|
||||
log.Err(err).Msg("failed to list error page resources")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
|
@ -57,11 +57,11 @@ func loadContent() {
|
|||
}
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
log.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||
continue
|
||||
}
|
||||
file = path.Base(file)
|
||||
logging.Info().Msgf("error page resource %s loaded", file)
|
||||
log.Info().Msgf("error page resource %s loaded", file)
|
||||
fileContentMap.Store(file, content)
|
||||
}
|
||||
}
|
||||
|
@ -83,9 +83,9 @@ func watchDir() {
|
|||
loadContent()
|
||||
case events.ActionFileDeleted:
|
||||
fileContentMap.Delete(filename)
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||
case events.ActionFileRenamed:
|
||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
||||
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||
fileContentMap.Delete(filename)
|
||||
loadContent()
|
||||
}
|
||||
|
|
|
@ -8,11 +8,11 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -87,7 +87,7 @@ func NewMiddleware[ImplType any]() *Middleware {
|
|||
func (m *Middleware) enableTrace() {
|
||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||
tracer.enableTrace()
|
||||
logging.Trace().Msgf("middleware %s enabled trace", m.name)
|
||||
log.Trace().Msgf("middleware %s enabled trace", m.name)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,14 +118,14 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
|||
"priority": optsRaw["priority"],
|
||||
"bypass": optsRaw["bypass"],
|
||||
}
|
||||
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
||||
if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
||||
return err
|
||||
}
|
||||
optsRaw = maps.Clone(optsRaw)
|
||||
for k := range commonOpts {
|
||||
delete(optsRaw, k)
|
||||
}
|
||||
return utils.MapUnmarshalValidate(optsRaw, m.impl)
|
||||
return serialization.MapUnmarshalValidate(optsRaw, m.impl)
|
||||
}
|
||||
|
||||
func (m *Middleware) finalize() error {
|
||||
|
|
|
@ -3,9 +3,9 @@ package middleware
|
|||
import (
|
||||
"path"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
@ -59,7 +59,7 @@ func LoadComposeFiles() {
|
|||
errs := gperr.NewBuilder("middleware compile errors")
|
||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||
if err != nil {
|
||||
logging.Err(err).Msg("failed to list middleware definitions")
|
||||
log.Err(err).Msg("failed to list middleware definitions")
|
||||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
|
@ -75,7 +75,7 @@ func LoadComposeFiles() {
|
|||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
log.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
|
@ -94,7 +94,7 @@ func LoadComposeFiles() {
|
|||
continue
|
||||
}
|
||||
allMiddlewares[name] = m
|
||||
logging.Info().
|
||||
log.Info().
|
||||
Str("src", path.Base(defFile)).
|
||||
Str("name", name).
|
||||
Msg("middleware loaded")
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
|
@ -138,7 +138,7 @@ func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper
|
|||
panic("nil transport")
|
||||
}
|
||||
rp := &ReverseProxy{
|
||||
Logger: logging.With().Str("name", name).Logger(),
|
||||
Logger: log.With().Str("name", name).Logger(),
|
||||
Transport: transport,
|
||||
TargetName: name,
|
||||
TargetURL: target,
|
||||
|
@ -173,17 +173,17 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
|||
case errors.Is(err, context.Canceled),
|
||||
errors.Is(err, io.EOF),
|
||||
errors.Is(err, context.DeadlineExceeded):
|
||||
logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
log.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
default:
|
||||
var recordErr tls.RecordHeaderError
|
||||
if errors.As(err, &recordErr) {
|
||||
logging.Error().
|
||||
log.Error().
|
||||
Str("url", reqURL).
|
||||
Msgf(`scheme was likely misconfigured as https,
|
||||
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
|
||||
logging.Err(err).Msg("underlying error")
|
||||
log.Err(err).Msg("underlying error")
|
||||
} else {
|
||||
logging.Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
log.Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
transport := p.Transport
|
||||
|
||||
ctx := req.Context()
|
||||
/* trunk-ignore(golangci-lint/revive) */
|
||||
if ctx.Done() != nil {
|
||||
// CloseNotifier predates context.Context, and has been
|
||||
// entirely superseded by it. If the request contains
|
||||
|
@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
return nil
|
||||
},
|
||||
}
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
|
||||
|
@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
|||
res.Header = rw.Header()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
//nolint:errorlint
|
||||
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
/* trunk-ignore(golangci-lint/errorlint) */
|
||||
//nolint:errorlint
|
||||
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
|
||||
return
|
||||
}
|
||||
|
||||
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
|
||||
/* trunk-ignore(golangci-lint/errcheck) */
|
||||
//nolint:errcheck
|
||||
bdp.Start()
|
||||
}
|
||||
|
|
@ -9,14 +9,14 @@ import (
|
|||
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/acl"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
)
|
||||
|
||||
type CertProvider interface {
|
||||
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
|
@ -53,7 +53,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) {
|
|||
func NewServer(opt Options) (s *Server) {
|
||||
var httpSer, httpsSer *http.Server
|
||||
|
||||
logger := logging.With().Str("server", opt.Name).Logger()
|
||||
logger := log.With().Str("server", opt.Name).Logger()
|
||||
|
||||
certAvailable := false
|
||||
if opt.CertProvider != nil {
|
||||
|
|
|
@ -2,7 +2,7 @@ package notif
|
|||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type NotificationConfig struct {
|
||||
|
@ -46,5 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error)
|
|||
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
|
||||
}
|
||||
|
||||
return utils.MapUnmarshalValidate(m, cfg.Provider)
|
||||
return serialization.MapUnmarshalValidate(m, cfg.Provider)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -181,7 +181,7 @@ func TestNotificationConfig(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
var cfg NotificationConfig
|
||||
provider := tt.cfg["provider"]
|
||||
err := utils.MapUnmarshalValidate(tt.cfg, &cfg)
|
||||
err := serialization.MapUnmarshalValidate(tt.cfg, &cfg)
|
||||
if tt.wantErr {
|
||||
ExpectHasError(t, err)
|
||||
} else {
|
||||
|
|
|
@ -8,14 +8,14 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
)
|
||||
|
||||
type (
|
||||
Provider interface {
|
||||
utils.CustomValidator
|
||||
serialization.CustomValidator
|
||||
|
||||
GetName() string
|
||||
GetURL() string
|
||||
|
@ -73,7 +73,7 @@ func (msg *LogMessage) notify(ctx context.Context, provider Provider) error {
|
|||
switch resp.StatusCode {
|
||||
case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent:
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logging.Debug().
|
||||
log.Debug().
|
||||
Str("provider", provider.GetName()).
|
||||
Str("url", provider.GetURL()).
|
||||
Str("status", resp.Status).
|
||||
|
|
|
@ -7,12 +7,12 @@ import (
|
|||
"github.com/docker/docker/client"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"github.com/yusing/go-proxy/internal/watcher"
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ func DockerProviderImpl(name, dockerHost string) ProviderImpl {
|
|||
return &DockerProvider{
|
||||
name,
|
||||
dockerHost,
|
||||
logging.With().Str("type", "docker").Str("name", name).Logger(),
|
||||
log.With().Str("type", "docker").Str("name", name).Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
|||
// Always non-nil.
|
||||
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
|
||||
if !container.IsExplicit && p.IsExplicitOnly() {
|
||||
return nil, nil
|
||||
return make(route.Routes, 0), nil
|
||||
}
|
||||
|
||||
routes := make(route.Routes, len(container.Aliases))
|
||||
|
@ -180,7 +180,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
|
|||
}
|
||||
|
||||
// deserialize map into entry object
|
||||
err := U.MapUnmarshalValidate(entryMap, r)
|
||||
err := serialization.MapUnmarshalValidate(entryMap, r)
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(alias))
|
||||
} else {
|
||||
|
@ -189,7 +189,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
|
|||
}
|
||||
if wildcardProps != nil {
|
||||
for _, re := range routes {
|
||||
if err := U.MapUnmarshalValidate(wildcardProps, re); err != nil {
|
||||
if err := serialization.MapUnmarshalValidate(wildcardProps, re); err != nil {
|
||||
errs.Add(err.Subject(docker.WildcardAlias))
|
||||
break
|
||||
}
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/route"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
W "github.com/yusing/go-proxy/internal/watcher"
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
|
|||
impl := &FileProvider{
|
||||
fileName: filename,
|
||||
path: path.Join(common.ConfigBasePath, filename),
|
||||
l: logging.With().Str("type", "file").Str("name", filename).Logger(),
|
||||
l: log.With().Str("type", "file").Str("name", filename).Logger(),
|
||||
}
|
||||
_, err := os.Stat(impl.path)
|
||||
if err != nil {
|
||||
|
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
|
|||
}
|
||||
|
||||
func validate(data []byte) (routes route.Routes, err gperr.Error) {
|
||||
err = utils.UnmarshalValidateYAML(data, &routes)
|
||||
err = serialization.UnmarshalValidateYAML(data, &routes)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
netutils "github.com/yusing/go-proxy/internal/net"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxmox"
|
||||
|
@ -116,7 +116,7 @@ func (r *Route) Validate() gperr.Error {
|
|||
Subject(containerName)
|
||||
}
|
||||
|
||||
l := logging.With().Str("container", containerName).Logger()
|
||||
l := log.With().Str("container", containerName).Logger()
|
||||
|
||||
l.Info().Msg("checking if container is running")
|
||||
running, err := node.LXCIsRunning(ctx, vmid)
|
||||
|
|
|
@ -3,7 +3,7 @@ package rules
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) {
|
|||
var rules struct {
|
||||
Rules Rules
|
||||
}
|
||||
err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules)
|
||||
err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, len(rules.Rules), len(test))
|
||||
ExpectEqual(t, rules.Rules[0].Name, "test")
|
||||
|
|
|
@ -5,9 +5,9 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/idlewatcher"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -32,7 +32,7 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
|
|||
// TODO: support non-coherent scheme
|
||||
return &StreamRoute{
|
||||
Route: base,
|
||||
l: logging.With().
|
||||
l: log.With().
|
||||
Str("type", string(base.Scheme)).
|
||||
Str("name", base.Name()).
|
||||
Logger(),
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
. "github.com/yusing/go-proxy/internal/route"
|
||||
route "github.com/yusing/go-proxy/internal/route/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/serialization"
|
||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := Route{}
|
||||
tt.input["host"] = "internal"
|
||||
err := utils.MapUnmarshalValidate(tt.input, &cfg)
|
||||
err := serialization.MapUnmarshalValidate(tt.input, &cfg)
|
||||
if err != nil {
|
||||
expect.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
@ -88,7 +88,7 @@ func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
|
|||
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
|
||||
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
|
||||
if err == nil {
|
||||
logging.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
|
||||
log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func (conn *UDPConn) read() (err error) {
|
|||
conn.buf.oobn = 0
|
||||
}
|
||||
if err == nil {
|
||||
logging.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -110,7 +110,7 @@ func (conn *UDPConn) read() (err error) {
|
|||
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
|
||||
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
|
||||
if err == nil {
|
||||
logging.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
|
||||
log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -120,12 +120,12 @@ func (conn *UDPConn) write() (err error) {
|
|||
case *net.UDPConn:
|
||||
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
|
||||
if err == nil {
|
||||
logging.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
}
|
||||
default:
|
||||
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
|
||||
if err == nil {
|
||||
logging.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
|
||||
log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package utils
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -12,7 +12,9 @@ import (
|
|||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
@ -40,14 +42,14 @@ var (
|
|||
|
||||
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||
|
||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||
var defaultValues = xsync.NewMapOf[reflect.Type, func() any]()
|
||||
|
||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||
t := reflect.TypeFor[T]()
|
||||
if t.Kind() == reflect.Ptr {
|
||||
panic("pointer of pointer")
|
||||
}
|
||||
if defaultValues.Has(t) {
|
||||
if _, ok := defaultValues.Load(t); ok {
|
||||
panic("default value for " + t.String() + " already registered")
|
||||
}
|
||||
defaultValues.Store(t, func() any { return factory() })
|
||||
|
@ -259,7 +261,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
|
|||
errs.Add(err.Subject(k))
|
||||
}
|
||||
} else {
|
||||
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
|
||||
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping))))
|
||||
}
|
||||
}
|
||||
if hasValidateTag && checkValidateTag {
|
|
@ -1,4 +1,4 @@
|
|||
package utils
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"reflect"
|
|
@ -1,4 +1,4 @@
|
|||
package utils
|
||||
package serialization
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
|
@ -7,9 +7,9 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
|
@ -116,14 +116,14 @@ func (t *Task) Finish(reason any) {
|
|||
func (t *Task) finish(reason any) {
|
||||
t.cancel(fmtCause(reason))
|
||||
if !waitWithTimeout(t.childrenDone) {
|
||||
logging.Debug().
|
||||
log.Debug().
|
||||
Str("task", t.name).
|
||||
Strs("subtasks", t.listChildren()).
|
||||
Msg("Timeout waiting for subtasks to finish")
|
||||
}
|
||||
go t.runCallbacks()
|
||||
if !waitWithTimeout(t.callbacksDone) {
|
||||
logging.Debug().
|
||||
log.Debug().
|
||||
Str("task", t.name).
|
||||
Strs("callbacks", t.listCallbacks()).
|
||||
Msg("Timeout waiting for callbacks to finish")
|
||||
|
@ -134,7 +134,7 @@ func (t *Task) finish(reason any) {
|
|||
}
|
||||
t.parent.subChildCount()
|
||||
allTasks.Remove(t)
|
||||
logging.Trace().Msg("task " + t.name + " finished")
|
||||
log.Trace().Msg("task " + t.name + " finished")
|
||||
}
|
||||
|
||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||
|
@ -166,7 +166,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
|||
}()
|
||||
}
|
||||
|
||||
logging.Trace().Msg("task " + child.name + " started")
|
||||
log.Trace().Msg("task " + child.name + " started")
|
||||
return child
|
||||
}
|
||||
|
||||
|
@ -189,7 +189,7 @@ func (t *Task) MarshalText() ([]byte, error) {
|
|||
func (t *Task) invokeWithRecover(fn func(), caller string) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logging.Error().
|
||||
log.Error().
|
||||
Interface("err", err).
|
||||
Msg("panic in task " + t.name + "." + caller)
|
||||
if common.IsDebug {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
|
@ -68,10 +68,10 @@ func GracefulShutdown(timeout time.Duration) (err error) {
|
|||
case <-after:
|
||||
b, err := json.Marshal(DebugTaskList())
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msg("failed to marshal tasks")
|
||||
log.Warn().Err(err).Msg("failed to marshal tasks")
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
||||
log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
@ -87,6 +87,6 @@ func WaitExit(shutdownTimeout int) {
|
|||
<-sig
|
||||
|
||||
// gracefully shutdown
|
||||
logging.Info().Msg("shutting down")
|
||||
log.Info().Msg("shutting down")
|
||||
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
|
||||
}
|
||||
|
|
21
internal/utils/go.mod
Normal file
21
internal/utils/go.mod
Normal file
|
@ -0,0 +1,21 @@
|
|||
module github.com/yusing/go-proxy/internal/utils
|
||||
|
||||
go 1.24.3
|
||||
|
||||
require (
|
||||
github.com/goccy/go-yaml v1.17.1
|
||||
github.com/puzpuzpuz/xsync/v4 v4.1.0
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
golang.org/x/text v0.25.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
36
internal/utils/go.sum
Normal file
36
internal/utils/go.sum
Normal file
|
@ -0,0 +1,36 @@
|
|||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
|
||||
github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U=
|
||||
github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -8,7 +8,6 @@ import (
|
|||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/synk"
|
||||
)
|
||||
|
||||
|
@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
|
|||
}
|
||||
}
|
||||
|
||||
func (p BidirectionalPipe) Start() gperr.Error {
|
||||
func (p BidirectionalPipe) Start() error {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
b := gperr.NewBuilder("bidirectional pipe error")
|
||||
var srcErr, dstErr error
|
||||
go func() {
|
||||
b.Add(p.pSrcDst.Start())
|
||||
srcErr = p.pSrcDst.Start()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
b.Add(p.pDstSrc.Start())
|
||||
dstErr = p.pDstSrc.Start()
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
return b.Error()
|
||||
return errors.Join(srcErr, dstErr)
|
||||
}
|
||||
|
||||
type httpFlusher interface {
|
||||
|
@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
|
|||
wCloser, wCanClose := dst.Writer.(io.Closer)
|
||||
rCloser, rCanClose := src.Reader.(io.Closer)
|
||||
if wCanClose || rCanClose {
|
||||
if src.ctx == dst.ctx {
|
||||
go func() {
|
||||
<-src.ctx.Done()
|
||||
if wCanClose {
|
||||
wCloser.Close()
|
||||
}
|
||||
if rCanClose {
|
||||
rCloser.Close()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
if wCloser != nil {
|
||||
go func() {
|
||||
<-src.ctx.Done()
|
||||
wCloser.Close()
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case <-src.ctx.Done():
|
||||
case <-dst.ctx.Done():
|
||||
}
|
||||
if rCloser != nil {
|
||||
go func() {
|
||||
<-dst.ctx.Done()
|
||||
rCloser.Close()
|
||||
}()
|
||||
if rCanClose {
|
||||
defer rCloser.Close()
|
||||
}
|
||||
}
|
||||
if wCanClose {
|
||||
defer wCloser.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
flusher := getHTTPFlusher(dst.Writer)
|
||||
canFlush := flusher != nil
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -29,12 +29,12 @@ func (p Pool[T]) Name() string {
|
|||
func (p Pool[T]) Add(obj T) {
|
||||
p.checkExists(obj.Key())
|
||||
p.m.Store(obj.Key(), obj)
|
||||
logging.Info().Msgf("%s: added %s", p.name, obj.Name())
|
||||
log.Info().Msgf("%s: added %s", p.name, obj.Name())
|
||||
}
|
||||
|
||||
func (p Pool[T]) Del(obj T) {
|
||||
p.m.Delete(obj.Key())
|
||||
logging.Info().Msgf("%s: removed %s", p.name, obj.Name())
|
||||
log.Info().Msgf("%s: removed %s", p.name, obj.Name())
|
||||
}
|
||||
|
||||
func (p Pool[T]) Get(key string) (T, bool) {
|
||||
|
|
|
@ -5,11 +5,11 @@ package pool
|
|||
import (
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func (p Pool[T]) checkExists(key string) {
|
||||
if _, ok := p.m.Load(key); ok {
|
||||
logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
|
||||
log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,16 +1,35 @@
|
|||
package synk
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type weakBuf = unsafe.Pointer
|
||||
|
||||
func makeWeak(b *[]byte) weakBuf {
|
||||
ptr := runtime_registerWeakPointer(unsafe.Pointer(b))
|
||||
runtime.KeepAlive(ptr)
|
||||
addCleanup(b, addGCed, cap(*b))
|
||||
return weakBuf(ptr)
|
||||
}
|
||||
|
||||
func getBufFromWeak(w weakBuf) []byte {
|
||||
ptr := (*[]byte)(runtime_makeStrongFromWeak(w))
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
//go:linkname runtime_registerWeakPointer weak.runtime_registerWeakPointer
|
||||
func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer
|
||||
|
||||
//go:linkname runtime_makeStrongFromWeak weak.runtime_makeStrongFromWeak
|
||||
func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer
|
||||
|
||||
type BytesPool struct {
|
||||
pool chan []byte
|
||||
pool chan weakBuf
|
||||
initSize int
|
||||
}
|
||||
|
||||
|
@ -22,19 +41,15 @@ const (
|
|||
const (
|
||||
InPoolLimit = 32 * mb
|
||||
|
||||
DefaultInitBytes = 32 * kb
|
||||
PoolThreshold = 64 * kb
|
||||
DefaultInitBytes = 4 * kb
|
||||
PoolThreshold = 256 * kb
|
||||
DropThresholdHigh = 4 * mb
|
||||
|
||||
PoolSize = InPoolLimit / PoolThreshold
|
||||
|
||||
CleanupInterval = 5 * time.Second
|
||||
MaxDropsPerCycle = 10
|
||||
MaxChecksPerCycle = 100
|
||||
)
|
||||
|
||||
var bytesPool = &BytesPool{
|
||||
pool: make(chan []byte, PoolSize),
|
||||
pool: make(chan weakBuf, PoolSize),
|
||||
initSize: DefaultInitBytes,
|
||||
}
|
||||
|
||||
|
@ -43,12 +58,18 @@ func NewBytesPool() *BytesPool {
|
|||
}
|
||||
|
||||
func (p *BytesPool) Get() []byte {
|
||||
select {
|
||||
case b := <-p.pool:
|
||||
subInPoolSize(int64(cap(b)))
|
||||
return b
|
||||
default:
|
||||
return make([]byte, 0, p.initSize)
|
||||
for {
|
||||
select {
|
||||
case bWeak := <-p.pool:
|
||||
bPtr := getBufFromWeak(bWeak)
|
||||
if bPtr == nil {
|
||||
continue
|
||||
}
|
||||
addReused(cap(bPtr))
|
||||
return bPtr
|
||||
default:
|
||||
return make([]byte, 0, p.initSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,90 +77,43 @@ func (p *BytesPool) GetSized(size int) []byte {
|
|||
if size <= PoolThreshold {
|
||||
return make([]byte, size)
|
||||
}
|
||||
select {
|
||||
case b := <-p.pool:
|
||||
if size <= cap(b) {
|
||||
subInPoolSize(int64(cap(b)))
|
||||
return b[:size]
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case p.pool <- b:
|
||||
addInPoolSize(int64(cap(b)))
|
||||
case bWeak := <-p.pool:
|
||||
bPtr := getBufFromWeak(bWeak)
|
||||
if bPtr == nil {
|
||||
continue
|
||||
}
|
||||
capB := cap(bPtr)
|
||||
if capB >= size {
|
||||
addReused(capB)
|
||||
return (bPtr)[:size]
|
||||
}
|
||||
select {
|
||||
case p.pool <- bWeak:
|
||||
default:
|
||||
// just drop it
|
||||
}
|
||||
default:
|
||||
}
|
||||
default:
|
||||
return make([]byte, size)
|
||||
}
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (p *BytesPool) Put(b []byte) {
|
||||
size := cap(b)
|
||||
if size > DropThresholdHigh || poolFull() {
|
||||
if size <= PoolThreshold || size > DropThresholdHigh {
|
||||
return
|
||||
}
|
||||
b = b[:0]
|
||||
w := makeWeak(&b)
|
||||
select {
|
||||
case p.pool <- b:
|
||||
addInPoolSize(int64(size))
|
||||
return
|
||||
case p.pool <- w:
|
||||
default:
|
||||
// just drop it
|
||||
}
|
||||
}
|
||||
|
||||
var inPoolSize int64
|
||||
|
||||
func addInPoolSize(size int64) {
|
||||
atomic.AddInt64(&inPoolSize, size)
|
||||
}
|
||||
|
||||
func subInPoolSize(size int64) {
|
||||
atomic.AddInt64(&inPoolSize, -size)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Periodically drop some buffers to prevent excessive memory usage
|
||||
go func() {
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt)
|
||||
|
||||
cleanupTicker := time.NewTicker(CleanupInterval)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cleanupTicker.C:
|
||||
dropBuffers()
|
||||
case <-sigCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func poolFull() bool {
|
||||
return atomic.LoadInt64(&inPoolSize) >= InPoolLimit
|
||||
}
|
||||
|
||||
// dropBuffers removes excess buffers from the pool when it grows too large.
|
||||
func dropBuffers() {
|
||||
// Check if pool has more than a threshold of buffers
|
||||
count := 0
|
||||
droppedSize := 0
|
||||
checks := 0
|
||||
for count < MaxDropsPerCycle && checks < MaxChecksPerCycle && atomic.LoadInt64(&inPoolSize) > InPoolLimit*2/3 {
|
||||
select {
|
||||
case b := <-bytesPool.pool:
|
||||
n := cap(b)
|
||||
subInPoolSize(int64(n))
|
||||
droppedSize += n
|
||||
count++
|
||||
default:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
checks++
|
||||
}
|
||||
if count > 0 {
|
||||
logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
|
||||
}
|
||||
initPoolStats()
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024}
|
||||
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024}
|
||||
|
||||
func BenchmarkBytesPool_GetSmall(b *testing.B) {
|
||||
for b.Loop() {
|
||||
|
|
55
internal/utils/synk/pool_debug.go
Normal file
55
internal/utils/synk/pool_debug.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
//go:build !production
|
||||
|
||||
package synk
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var (
|
||||
numReused, sizeReused uint64
|
||||
numGCed, sizeGCed uint64
|
||||
)
|
||||
|
||||
func addReused(size int) {
|
||||
atomic.AddUint64(&numReused, 1)
|
||||
atomic.AddUint64(&sizeReused, uint64(size))
|
||||
}
|
||||
|
||||
func addGCed(size int) {
|
||||
atomic.AddUint64(&numGCed, 1)
|
||||
atomic.AddUint64(&sizeGCed, uint64(size))
|
||||
}
|
||||
|
||||
var addCleanup = runtime.AddCleanup[[]byte, int]
|
||||
|
||||
func initPoolStats() {
|
||||
go func() {
|
||||
statsTicker := time.NewTicker(5 * time.Second)
|
||||
defer statsTicker.Stop()
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, os.Interrupt)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sig:
|
||||
return
|
||||
case <-statsTicker.C:
|
||||
log.Info().
|
||||
Uint64("numReused", atomic.LoadUint64(&numReused)).
|
||||
Str("sizeReused", strutils.FormatByteSize(atomic.LoadUint64(&sizeReused))).
|
||||
Uint64("numGCed", atomic.LoadUint64(&numGCed)).
|
||||
Str("sizeGCed", strutils.FormatByteSize(atomic.LoadUint64(&sizeGCed))).
|
||||
Msg("bytes pool stats")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
8
internal/utils/synk/pool_prod.go
Normal file
8
internal/utils/synk/pool_prod.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
//go:build production
|
||||
|
||||
package synk
|
||||
|
||||
func addReused(size int) {}
|
||||
func addGCed(size int) {}
|
||||
func initPoolStats() {}
|
||||
func addCleanup(ptr *[]byte, cleanup func(int), arg int) {}
|
|
@ -2,14 +2,16 @@ package expect
|
|||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
var isTest = strings.HasSuffix(os.Args[0], ".test")
|
||||
|
||||
func init() {
|
||||
if common.IsTest {
|
||||
if isTest {
|
||||
// force verbose output
|
||||
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
||||
}
|
||||
|
|
|
@ -1,19 +1,11 @@
|
|||
package expect
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if common.IsTest {
|
||||
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -3,7 +3,6 @@ package utils
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
|
@ -38,8 +37,6 @@ func init() {
|
|||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-task.RootContext().Done():
|
||||
return
|
||||
case <-timeNowTicker.C:
|
||||
shouldCallTimeNow.Store(true)
|
||||
}
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
@ -41,13 +41,13 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
|
|||
//! subdirectories are not watched
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
logging.Panic().Err(err).Msg("unable to create fs watcher")
|
||||
log.Panic().Err(err).Msg("unable to create fs watcher")
|
||||
}
|
||||
if err = w.Add(dirPath); err != nil {
|
||||
logging.Panic().Err(err).Msg("unable to create fs watcher")
|
||||
log.Panic().Err(err).Msg("unable to create fs watcher")
|
||||
}
|
||||
helper := &DirWatcher{
|
||||
Logger: logging.With().
|
||||
Logger: log.With().
|
||||
Str("type", "dir").
|
||||
Str("path", dirPath).
|
||||
Logger(),
|
||||
|
|
|
@ -8,9 +8,9 @@ import (
|
|||
docker_events "github.com/docker/docker/api/types/events"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||
)
|
||||
|
||||
|
@ -81,7 +81,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
|
|||
}()
|
||||
|
||||
cEventCh, cErrCh := client.Events(ctx, options)
|
||||
defer logging.Debug().Str("host", client.Address()).Msg("docker watcher closed")
|
||||
defer log.Debug().Str("host", client.Address()).Msg("docker watcher closed")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
@ -153,7 +153,7 @@ func checkConnection(ctx context.Context, client *docker.SharedClient) bool {
|
|||
defer cancel()
|
||||
err := client.CheckConnection(ctx)
|
||||
if err != nil {
|
||||
logging.Debug().Err(err).Msg("docker watcher: connection failed")
|
||||
log.Debug().Err(err).Msg("docker watcher: connection failed")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
|
|
@ -7,9 +7,9 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/docker"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/notif"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
|
@ -48,7 +48,7 @@ func NewMonitor(r routes.Route) health.HealthMonCheck {
|
|||
case routes.StreamRoute:
|
||||
mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig())
|
||||
default:
|
||||
logging.Panic().Msgf("unexpected route type: %T", r)
|
||||
log.Panic().Msgf("unexpected route type: %T", r)
|
||||
}
|
||||
}
|
||||
if r.IsDocker() {
|
||||
|
@ -91,7 +91,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
|
|||
mon.task = parent.Subtask("health_monitor", true)
|
||||
|
||||
go func() {
|
||||
logger := logging.With().Str("name", mon.service).Logger()
|
||||
logger := log.With().Str("name", mon.service).Logger()
|
||||
|
||||
defer func() {
|
||||
if mon.status.Load() != health.StatusError {
|
||||
|
@ -221,7 +221,7 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (mon *monitor) checkUpdateHealth() error {
|
||||
logger := logging.With().Str("name", mon.Name()).Logger()
|
||||
logger := log.With().Str("name", mon.Name()).Logger()
|
||||
result, err := mon.checkHealth()
|
||||
|
||||
var lastStatus health.Status
|
||||
|
|
|
@ -18,7 +18,7 @@ func GetLastVersion() Version {
|
|||
|
||||
func GetVersionHTTPHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(GetVersion().String()))
|
||||
fmt.Fprint(w, GetVersion().String())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,16 +34,16 @@ func init() {
|
|||
// lastVersion = ParseVersion(lastVersionStr)
|
||||
// }
|
||||
// if err != nil && !os.IsNotExist(err) {
|
||||
// logging.Warn().Err(err).Msg("failed to read version file")
|
||||
// log.Warn().Err(err).Msg("failed to read version file")
|
||||
// return
|
||||
// }
|
||||
// if err := f.Truncate(0); err != nil {
|
||||
// logging.Warn().Err(err).Msg("failed to truncate version file")
|
||||
// log.Warn().Err(err).Msg("failed to truncate version file")
|
||||
// return
|
||||
// }
|
||||
// _, err = f.WriteString(version)
|
||||
// if err != nil {
|
||||
// logging.Warn().Err(err).Msg("failed to save version file")
|
||||
// log.Warn().Err(err).Msg("failed to save version file")
|
||||
// return
|
||||
// }
|
||||
}
|
||||
|
|
|
@ -2,4 +2,19 @@ module github.com/yusing/go-proxy/socketproxy
|
|||
|
||||
go 1.24.3
|
||||
|
||||
require github.com/gorilla/mux v1.8.1
|
||||
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
|
||||
|
||||
require (
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/yusing/go-proxy/internal/utils v0.0.0-00010101000000-000000000000
|
||||
golang.org/x/net v0.40.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
)
|
||||
|
|
|
@ -1,2 +1,34 @@
|
|||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
|
||||
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -4,30 +4,33 @@ import (
|
|||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/yusing/go-proxy/socketproxy/pkg/reverseproxy"
|
||||
)
|
||||
|
||||
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
|
||||
|
||||
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", DockerSocket)
|
||||
func dialDockerSocket(socket string) func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "unix", socket)
|
||||
}
|
||||
}
|
||||
|
||||
var DockerSocketHandler = dockerSocketHandler
|
||||
|
||||
func dockerSocketHandler() http.HandlerFunc {
|
||||
rp := &httputil.ReverseProxy{
|
||||
func dockerSocketHandler(socket string) http.HandlerFunc {
|
||||
rp := &reverseproxy.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = "api.moby.localhost"
|
||||
req.RequestURI = req.URL.String()
|
||||
},
|
||||
Transport: &http.Transport{
|
||||
DialContext: dialDockerSocket,
|
||||
DialContext: dialDockerSocket(socket),
|
||||
DisableCompression: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -41,7 +44,7 @@ func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
|
|||
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
|
||||
func NewHandler() http.Handler {
|
||||
r := mux.NewRouter()
|
||||
socketHandler := DockerSocketHandler()
|
||||
socketHandler := DockerSocketHandler(DockerSocket)
|
||||
|
||||
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
|
||||
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
. "github.com/yusing/go-proxy/socketproxy/pkg"
|
||||
)
|
||||
|
||||
func mockDockerSocketHandler() http.HandlerFunc {
|
||||
func mockDockerSocketHandler(_ string) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("mock docker response"))
|
||||
|
|
367
socket-proxy/pkg/reverseproxy/reverse_proxy.go
Normal file
367
socket-proxy/pkg/reverseproxy/reverse_proxy.go
Normal file
|
@ -0,0 +1,367 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// License URL: https://cs.opensource.google/go/go/+/master:LICENSE
|
||||
|
||||
// HTTP reverse proxy handler
|
||||
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||
// sends it to another server, proxying the response back to the
|
||||
// client.
|
||||
//
|
||||
// 1xx responses are forwarded to the client if the underlying
|
||||
// transport supports ClientTrace.Got1xxResponse.
|
||||
type ReverseProxy struct {
|
||||
// Director is a function which modifies
|
||||
// the request into a new request to be sent
|
||||
// using Transport. Its response is then copied
|
||||
// back to the original client unmodified.
|
||||
// Director must not access the provided Request
|
||||
// after returning.
|
||||
//
|
||||
// By default, the X-Forwarded-For header is set to the
|
||||
// value of the client IP address. If an X-Forwarded-For
|
||||
// header already exists, the client IP is appended to the
|
||||
// existing values. As a special case, if the header
|
||||
// exists in the Request.Header map but has a nil value
|
||||
// (such as when set by the Director func), the X-Forwarded-For
|
||||
// header is not modified.
|
||||
//
|
||||
// To prevent IP spoofing, be sure to delete any pre-existing
|
||||
// X-Forwarded-For header coming from the client or
|
||||
// an untrusted proxy.
|
||||
//
|
||||
// Hop-by-hop headers are removed from the request after
|
||||
// Director returns, which can remove headers added by
|
||||
// Director. Use a Rewrite function instead to ensure
|
||||
// modifications to the request are preserved.
|
||||
//
|
||||
// Unparsable query parameters are removed from the outbound
|
||||
// request if Request.Form is set after Director returns.
|
||||
//
|
||||
// At most one of Rewrite or Director may be set.
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// ErrorLog specifies an optional logger for errors
|
||||
// that occur when attempting to proxy the request.
|
||||
// If nil, logging is done via the log package's standard logger.
|
||||
ErrorLog *log.Logger
|
||||
|
||||
// ErrorHandler is an optional function that handles errors
|
||||
// reaching the backend or errors from ModifyResponse.
|
||||
//
|
||||
// If nil, the default is to log the provided error and return
|
||||
// a 502 Status Bad Gateway response.
|
||||
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
p.logf("http: proxy error: %v", err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
|
||||
if p.ErrorHandler != nil {
|
||||
return p.ErrorHandler
|
||||
}
|
||||
return p.defaultErrorHandler
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
ctx := req.Context()
|
||||
|
||||
outreq := req.Clone(ctx)
|
||||
if req.ContentLength == 0 {
|
||||
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||
}
|
||||
if outreq.Body != nil {
|
||||
// Reading from the request body after returning from a handler is not
|
||||
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||
// this handler. This can lead to a crash if the handler panics (see
|
||||
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||
// any Read in flight after the handle returns, in practice it's safe to
|
||||
// read after closing it.
|
||||
defer outreq.Body.Close()
|
||||
}
|
||||
if outreq.Header == nil {
|
||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
p.Director(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := upgradeType(outreq.Header)
|
||||
if !IsPrint(reqUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Del("Forwarded")
|
||||
removeHopByHopHeaders(outreq.Header)
|
||||
|
||||
// Issue 21096: tell backend applications that care about trailer support
|
||||
// that we support trailers. (We do, but we don't go out of our way to
|
||||
// advertise that unless the incoming client request thought it was worth
|
||||
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||
// the latter has passed through removeHopByHopHeaders.
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outreq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||
// If the outbound request doesn't have a User-Agent header set,
|
||||
// don't send the default Go HTTP client User-Agent.
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
var (
|
||||
roundTripMutex sync.Mutex
|
||||
roundTripDone bool
|
||||
)
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
roundTripMutex.Lock()
|
||||
defer roundTripMutex.Unlock()
|
||||
if roundTripDone {
|
||||
// If RoundTrip has returned, don't try to further modify
|
||||
// the ResponseWriter's header map.
|
||||
return nil
|
||||
}
|
||||
h := rw.Header()
|
||||
copyHeader(h, http.Header(header))
|
||||
rw.WriteHeader(code)
|
||||
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
clear(h)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
roundTripMutex.Lock()
|
||||
roundTripDone = true
|
||||
roundTripMutex.Unlock()
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, outreq, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
p.handleUpgradeResponse(rw, outreq, res)
|
||||
return
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(res.Header)
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
announcedTrailers := len(res.Trailer)
|
||||
if announcedTrailers > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
|
||||
err = utils.CopyCloseWithContext(ctx, rw, res.Body)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
p.getErrorHandler()(rw, req, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
http.NewResponseController(rw).Flush()
|
||||
}
|
||||
|
||||
if len(res.Trailer) == announcedTrailers {
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
return
|
||||
}
|
||||
|
||||
for k, vv := range res.Trailer {
|
||||
k = http.TrailerPrefix + k
|
||||
for _, v := range vv {
|
||||
rw.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders removes hop-by-hop headers.
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
for _, f := range h["Connection"] {
|
||||
for sf := range strings.SplitSeq(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) logf(format string, args ...any) {
|
||||
if p.ErrorLog != nil {
|
||||
p.ErrorLog.Printf(format, args...)
|
||||
} else {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||
return
|
||||
}
|
||||
|
||||
rc := http.NewResponseController(rw)
|
||||
conn, brw, hijackErr := rc.Hijack()
|
||||
if errors.Is(hijackErr, http.ErrNotSupported) {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
if hijackErr != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
res.Header = rw.Header()
|
||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||
if err := res.Write(brw); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
|
||||
return
|
||||
}
|
||||
if err := brw.Flush(); err != nil {
|
||||
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
|
||||
return
|
||||
}
|
||||
errc := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
<-errc
|
||||
}
|
||||
|
||||
// switchProtocolCopier exists so goroutines proxying data back and
|
||||
// forth have nice names in stacks.
|
||||
type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriter
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.user, c.backend)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.backend, c.user)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func IsPrint(s string) bool {
|
||||
for _, r := range s {
|
||||
if r < ' ' || r > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
Loading…
Add table
Reference in a new issue