mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-01 09:32:35 +02:00
fix: optimize memory usage, fix agent and code refactor (#118)
Some checks are pending
Docker Image CI (socket-proxy) / build (push) Waiting to run
Some checks are pending
Docker Image CI (socket-proxy) / build (push) Waiting to run
* refactor: simplify io code and make utils module independent * fix(docker): agent and socket-proxy docker event flushing with modified reverse proxy handler * refactor: remove unused code * refactor: remove the use of logging module in most code * refactor: streamline domain mismatch check in certState function * tweak: use ecdsa p-256 for autocert * fix(tests): update health check tests for invalid host and add case for port in host * feat(acme): custom acme directory * refactor: code refactor and improved context and error handling * tweak: optimize memory usage under load * fix(oidc): restore old user matching behavior * docs: add ChatGPT assistant to README --------- Co-authored-by: yusing <yusing@6uo.me>
This commit is contained in:
parent
ff08c40403
commit
4a8bd48ad5
98 changed files with 1549 additions and 555 deletions
|
@ -2,15 +2,16 @@ version: "2"
|
||||||
linters:
|
linters:
|
||||||
default: all
|
default: all
|
||||||
disable:
|
disable:
|
||||||
- bodyclose
|
# - bodyclose
|
||||||
- containedctx
|
- containedctx
|
||||||
- contextcheck
|
# - contextcheck
|
||||||
- cyclop
|
- cyclop
|
||||||
- depguard
|
- depguard
|
||||||
- dupl
|
# - dupl
|
||||||
- err113
|
- err113
|
||||||
- exhaustive
|
- exhaustive
|
||||||
- exhaustruct
|
- exhaustruct
|
||||||
|
- funcorder
|
||||||
- forcetypeassert
|
- forcetypeassert
|
||||||
- gochecknoglobals
|
- gochecknoglobals
|
||||||
- gochecknoinits
|
- gochecknoinits
|
||||||
|
@ -18,7 +19,6 @@ linters:
|
||||||
- goconst
|
- goconst
|
||||||
- gocyclo
|
- gocyclo
|
||||||
- gomoddirectives
|
- gomoddirectives
|
||||||
- gosec
|
|
||||||
- gosmopolitan
|
- gosmopolitan
|
||||||
- ireturn
|
- ireturn
|
||||||
- lll
|
- lll
|
||||||
|
@ -27,12 +27,10 @@ linters:
|
||||||
- mnd
|
- mnd
|
||||||
- nakedret
|
- nakedret
|
||||||
- nestif
|
- nestif
|
||||||
- nilnil
|
|
||||||
- nlreturn
|
- nlreturn
|
||||||
- noctx
|
|
||||||
- nonamedreturns
|
- nonamedreturns
|
||||||
- paralleltest
|
- paralleltest
|
||||||
- prealloc
|
- revive
|
||||||
- rowserrcheck
|
- rowserrcheck
|
||||||
- sqlclosecheck
|
- sqlclosecheck
|
||||||
- tagliatelle
|
- tagliatelle
|
||||||
|
|
|
@ -21,7 +21,7 @@ lint:
|
||||||
- markdownlint
|
- markdownlint
|
||||||
- yamllint
|
- yamllint
|
||||||
enabled:
|
enabled:
|
||||||
- checkov@3.2.416
|
- checkov@3.2.432
|
||||||
- golangci-lint2@2.1.6
|
- golangci-lint2@2.1.6
|
||||||
- hadolint@2.12.1-beta
|
- hadolint@2.12.1-beta
|
||||||
- actionlint@1.7.7
|
- actionlint@1.7.7
|
||||||
|
@ -32,7 +32,7 @@ lint:
|
||||||
- prettier@3.5.3
|
- prettier@3.5.3
|
||||||
- shellcheck@0.10.0
|
- shellcheck@0.10.0
|
||||||
- shfmt@3.6.0
|
- shfmt@3.6.0
|
||||||
- trufflehog@3.88.29
|
- trufflehog@3.88.33
|
||||||
actions:
|
actions:
|
||||||
disabled:
|
disabled:
|
||||||
- trunk-announce
|
- trunk-announce
|
||||||
|
|
|
@ -16,6 +16,8 @@ A lightweight, simple, and performant reverse proxy with WebUI.
|
||||||
|
|
||||||
<h5>EN | <a href="README_CHT.md">中文</a></h5>
|
<h5>EN | <a href="README_CHT.md">中文</a></h5>
|
||||||
|
|
||||||
|
Have questions? Ask [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)! (Thanks to [@ismesid](https://github.com/arevindh))
|
||||||
|
|
||||||
<img src="screenshots/webui.jpg" style="max-width: 650">
|
<img src="screenshots/webui.jpg" style="max-width: 650">
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
<h5><a href="README.md">EN</a> | 中文</h5>
|
<h5><a href="README.md">EN</a> | 中文</h5>
|
||||||
|
|
||||||
|
有疑問? 問 [ChatGPT](https://chatgpt.com/g/g-6825390374b481919ad482f2e48936a1-godoxy-assistant)!(鳴謝 [@ismesid](https://github.com/arevindh))
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/4bb371f4-6e4c-425c-89b2-b9e962bdd46f" style="max-width: 650">
|
<img src="https://github.com/user-attachments/assets/4bb371f4-6e4c-425c-89b2-b9e962bdd46f" style="max-width: 650">
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/server"
|
"github.com/yusing/go-proxy/agent/pkg/server"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||||
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
|
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -14,6 +17,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
writer := zerolog.ConsoleWriter{
|
||||||
|
Out: os.Stderr,
|
||||||
|
TimeFormat: "01-02 15:04",
|
||||||
|
}
|
||||||
|
zerolog.TimeFieldFormat = writer.TimeFormat
|
||||||
|
log.Logger = zerolog.New(writer).Level(zerolog.InfoLevel).With().Timestamp().Logger()
|
||||||
ca := &agent.PEMPair{}
|
ca := &agent.PEMPair{}
|
||||||
err := ca.Load(env.AgentCACert)
|
err := ca.Load(env.AgentCACert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,11 +43,11 @@ func main() {
|
||||||
gperr.LogFatal("init SSL error", err)
|
gperr.LogFatal("init SSL error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
|
log.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
|
||||||
logging.Info().Msgf("Agent name: %s", env.AgentName)
|
log.Info().Msgf("Agent name: %s", env.AgentName)
|
||||||
logging.Info().Msgf("Agent port: %d", env.AgentPort)
|
log.Info().Msgf("Agent port: %d", env.AgentPort)
|
||||||
|
|
||||||
logging.Info().Msg(`
|
log.Info().Msg(`
|
||||||
Tips:
|
Tips:
|
||||||
1. To change the agent name, you can set the AGENT_NAME environment variable.
|
1. To change the agent name, you can set the AGENT_NAME environment variable.
|
||||||
2. To change the agent port, you can set the AGENT_PORT environment variable.
|
2. To change the agent port, you can set the AGENT_PORT environment variable.
|
||||||
|
@ -54,7 +63,7 @@ Tips:
|
||||||
server.StartAgentServer(t, opts)
|
server.StartAgentServer(t, opts)
|
||||||
|
|
||||||
if socketproxy.ListenAddr != "" {
|
if socketproxy.ListenAddr != "" {
|
||||||
logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
|
log.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
|
||||||
opts := httpServer.Options{
|
opts := httpServer.Options{
|
||||||
Name: "docker",
|
Name: "docker",
|
||||||
HTTPAddr: socketproxy.ListenAddr,
|
HTTPAddr: socketproxy.ListenAddr,
|
||||||
|
|
|
@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy => ..
|
||||||
|
|
||||||
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
|
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
|
||||||
|
|
||||||
|
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
|
||||||
|
|
||||||
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
||||||
|
|
||||||
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327
|
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250523121925-f87c3159e327
|
||||||
|
@ -15,6 +17,7 @@ require (
|
||||||
github.com/rs/zerolog v1.34.0
|
github.com/rs/zerolog v1.34.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
||||||
|
github.com/yusing/go-proxy/internal/utils v0.0.0
|
||||||
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
|
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/certs"
|
"github.com/yusing/go-proxy/agent/pkg/certs"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/pkg"
|
"github.com/yusing/go-proxy/pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||||
|
|
||||||
cfg.name = string(name)
|
cfg.name = string(name)
|
||||||
|
|
||||||
cfg.l = logging.With().Str("agent", cfg.name).Logger()
|
cfg.l = log.With().Str("agent", cfg.name).Logger()
|
||||||
|
|
||||||
// check agent version
|
// check agent version
|
||||||
agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion)
|
agentVersionBytes, _, err := cfg.Fetch(ctx, EndpointVersion)
|
||||||
|
@ -127,10 +127,10 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||||
agentVersion := pkg.ParseVersion(cfg.version)
|
agentVersion := pkg.ParseVersion(cfg.version)
|
||||||
|
|
||||||
if serverVersion.IsNewerMajorThan(agentVersion) {
|
if serverVersion.IsNewerMajorThan(agentVersion) {
|
||||||
logging.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
|
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.name, serverVersion, agentVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Info().Msgf("agent %q initialized", cfg.name)
|
log.Info().Msgf("agent %q initialized", cfg.name)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -172,9 +172,9 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "InvalidHost",
|
name: "InvalidHost",
|
||||||
scheme: "tcp",
|
scheme: "tcp",
|
||||||
host: "invalid",
|
host: "",
|
||||||
port: 8080,
|
port: 8080,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedHealthy: false,
|
expectedHealthy: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -188,9 +188,17 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "InvalidHost",
|
name: "InvalidHost",
|
||||||
scheme: "udp",
|
scheme: "udp",
|
||||||
host: "invalid",
|
host: "",
|
||||||
port: 8080,
|
port: 8080,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedHealthy: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Port in both host and port",
|
||||||
|
scheme: "tcp",
|
||||||
|
host: "localhost:1234",
|
||||||
|
port: 1234,
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
expectedHealthy: false,
|
expectedHealthy: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -208,9 +216,11 @@ func TestCheckHealthTCPUDP(t *testing.T) {
|
||||||
|
|
||||||
require.Equal(t, recorder.Code, tt.expectedStatus)
|
require.Equal(t, recorder.Code, tt.expectedStatus)
|
||||||
|
|
||||||
var result health.HealthCheckResult
|
if tt.expectedStatus == http.StatusOK {
|
||||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
|
var result health.HealthCheckResult
|
||||||
require.Equal(t, result.Healthy, tt.expectedHealthy)
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &result))
|
||||||
|
require.Equal(t, result.Healthy, tt.expectedHealthy)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,14 @@
|
||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||||
"github.com/yusing/go-proxy/pkg"
|
"github.com/yusing/go-proxy/pkg"
|
||||||
|
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServeMux struct{ *http.ServeMux }
|
type ServeMux struct{ *http.ServeMux }
|
||||||
|
@ -24,26 +21,6 @@ func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
|
||||||
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
|
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
|
|
||||||
|
|
||||||
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
return dialer.DialContext(ctx, "unix", env.DockerSocket)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dockerSocketHandler() http.HandlerFunc {
|
|
||||||
rp := httputil.ReverseProxy{
|
|
||||||
Director: func(r *http.Request) {
|
|
||||||
r.URL.Scheme = "http"
|
|
||||||
r.URL.Host = "api.moby.localhost"
|
|
||||||
r.RequestURI = r.URL.String()
|
|
||||||
},
|
|
||||||
Transport: &http.Transport{
|
|
||||||
DialContext: dialDockerSocket,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return rp.ServeHTTP
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAgentHandler() http.Handler {
|
func NewAgentHandler() http.Handler {
|
||||||
mux := ServeMux{http.NewServeMux()}
|
mux := ServeMux{http.NewServeMux()}
|
||||||
|
|
||||||
|
@ -54,6 +31,6 @@ func NewAgentHandler() http.Handler {
|
||||||
})
|
})
|
||||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
|
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
|
||||||
mux.ServeMux.HandleFunc("/", dockerSocketHandler())
|
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler(env.DockerSocket))
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/env"
|
"github.com/yusing/go-proxy/agent/pkg/env"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/handler"
|
"github.com/yusing/go-proxy/agent/pkg/handler"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
@ -33,12 +33,11 @@ func StartAgentServer(parent task.Parent, opt Options) {
|
||||||
tlsConfig.ClientAuth = tls.NoClientCert
|
tlsConfig.ClientAuth = tls.NoClientCert
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logging.GetLogger()
|
|
||||||
agentServer := &http.Server{
|
agentServer := &http.Server{
|
||||||
Addr: fmt.Sprintf(":%d", opt.Port),
|
Addr: fmt.Sprintf(":%d", opt.Port),
|
||||||
Handler: handler.NewAgentHandler(),
|
Handler: handler.NewAgentHandler(),
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
server.Start(parent, agentServer, nil, logger)
|
server.Start(parent, agentServer, nil, &log.Logger)
|
||||||
}
|
}
|
||||||
|
|
11
cmd/main.go
11
cmd/main.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/auth"
|
"github.com/yusing/go-proxy/internal/auth"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
|
@ -35,8 +36,8 @@ func main() {
|
||||||
initProfiling()
|
initProfiling()
|
||||||
|
|
||||||
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
|
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
|
||||||
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
log.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||||
logging.Trace().Msg("trace enabled")
|
log.Trace().Msg("trace enabled")
|
||||||
parallel(
|
parallel(
|
||||||
dnsproviders.InitProviders,
|
dnsproviders.InitProviders,
|
||||||
homepage.InitIconListCache,
|
homepage.InitIconListCache,
|
||||||
|
@ -45,7 +46,7 @@ func main() {
|
||||||
)
|
)
|
||||||
|
|
||||||
if common.APIJWTSecret == nil {
|
if common.APIJWTSecret == nil {
|
||||||
logging.Warn().Msg("API_JWT_SECRET is not set, using random key")
|
log.Warn().Msg("API_JWT_SECRET is not set, using random key")
|
||||||
common.APIJWTSecret = common.RandomJWTKey()
|
common.APIJWTSecret = common.RandomJWTKey()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +63,7 @@ func main() {
|
||||||
Proxy: true,
|
Proxy: true,
|
||||||
})
|
})
|
||||||
if err := auth.Initialize(); err != nil {
|
if err := auth.Initialize(); err != nil {
|
||||||
logging.Fatal().Err(err).Msg("failed to initialize authentication")
|
log.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||||
}
|
}
|
||||||
// API Handler needs to start after auth is initialized.
|
// API Handler needs to start after auth is initialized.
|
||||||
cfg.StartServers(&config.StartServersOptions{
|
cfg.StartServers(&config.StartServersOptions{
|
||||||
|
@ -78,7 +79,7 @@ func main() {
|
||||||
func prepareDirectory(dir string) {
|
func prepareDirectory(dir string) {
|
||||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||||
if err = os.MkdirAll(dir, 0o755); err != nil {
|
if err = os.MkdirAll(dir, 0o755); err != nil {
|
||||||
logging.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
log.Fatal().Msgf("failed to create directory %s: %v", dir, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
7
go.mod
7
go.mod
|
@ -6,6 +6,8 @@ replace github.com/yusing/go-proxy/agent => ./agent
|
||||||
|
|
||||||
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
|
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
|
||||||
|
|
||||||
|
replace github.com/yusing/go-proxy/internal/utils => ./internal/utils
|
||||||
|
|
||||||
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22
|
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.0.0-20250523122447-f078841dec22
|
||||||
|
|
||||||
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250523125835-a2474a6ebe30
|
||||||
|
@ -45,7 +47,7 @@ require (
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000
|
github.com/yusing/go-proxy/agent v0.0.0-00010101000000-000000000000
|
||||||
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000
|
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-00010101000000-000000000000
|
||||||
go.uber.org/atomic v1.11.0
|
github.com/yusing/go-proxy/internal/utils v0.0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
@ -219,6 +221,7 @@ require (
|
||||||
go.opentelemetry.io/otel v1.36.0 // indirect
|
go.opentelemetry.io/otel v1.36.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||||
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||||
go.uber.org/mock v0.5.2 // indirect
|
go.uber.org/mock v0.5.2 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
|
@ -226,7 +229,7 @@ require (
|
||||||
golang.org/x/mod v0.24.0 // indirect
|
golang.org/x/mod v0.24.0 // indirect
|
||||||
golang.org/x/sync v0.14.0 // indirect
|
golang.org/x/sync v0.14.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
golang.org/x/text v0.25.0
|
golang.org/x/text v0.25.0 // indirect
|
||||||
golang.org/x/tools v0.33.0 // indirect
|
golang.org/x/tools v0.33.0 // indirect
|
||||||
google.golang.org/api v0.234.0 // indirect
|
google.golang.org/api v0.234.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect
|
||||||
|
|
|
@ -5,9 +5,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/maxmind"
|
"github.com/yusing/go-proxy/internal/maxmind"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -45,7 +45,7 @@ func (c *checkCache) Expired() bool {
|
||||||
return c.created.Add(cacheTTL).Before(utils.TimeNow())
|
return c.created.Add(cacheTTL).Before(utils.TimeNow())
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: add stats
|
// TODO: add stats
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ACLAllow = "allow"
|
ACLAllow = "allow"
|
||||||
|
@ -97,7 +97,7 @@ func (c *Config) Start(parent *task.Task) gperr.Error {
|
||||||
if c.valErr != nil {
|
if c.valErr != nil {
|
||||||
return c.valErr
|
return c.valErr
|
||||||
}
|
}
|
||||||
logging.Info().
|
log.Info().
|
||||||
Str("default", c.Default).
|
Str("default", c.Default).
|
||||||
Bool("allow_local", c.allowLocal).
|
Bool("allow_local", c.allowLocal).
|
||||||
Int("allow_rules", len(c.Allow)).
|
Int("allow_rules", len(c.Allow)).
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMatchers(t *testing.T) {
|
func TestMatchers(t *testing.T) {
|
||||||
|
@ -16,7 +16,7 @@ func TestMatchers(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var mathers Matchers
|
var mathers Matchers
|
||||||
err := utils.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
|
err := serialization.Convert(reflect.ValueOf(strMatchers), reflect.ValueOf(&mathers), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,12 +22,12 @@ func (noConn) SetDeadline(t time.Time) error { return nil }
|
||||||
func (noConn) SetReadDeadline(t time.Time) error { return nil }
|
func (noConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
func (noConn) SetWriteDeadline(t time.Time) error { return nil }
|
func (noConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
func (cfg *Config) WrapTCP(lis net.Listener) net.Listener {
|
func (c *Config) WrapTCP(lis net.Listener) net.Listener {
|
||||||
if cfg == nil {
|
if c == nil {
|
||||||
return lis
|
return lis
|
||||||
}
|
}
|
||||||
return &TCPListener{
|
return &TCPListener{
|
||||||
acl: cfg,
|
acl: c,
|
||||||
lis: lis,
|
lis: lis,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,9 +3,9 @@ package certapi
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
||||||
)
|
)
|
||||||
|
@ -36,7 +36,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) {
|
||||||
gperr.LogError("failed to obtain cert", err)
|
gperr.LogError("failed to obtain cert", err)
|
||||||
_ = gpwebsocket.WriteText(conn, err.Error())
|
_ = gpwebsocket.WriteText(conn, err.Error())
|
||||||
} else {
|
} else {
|
||||||
logging.Info().Msg("cert obtained successfully")
|
log.Info().Msg("cert obtained successfully")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
|
|
|
@ -9,12 +9,13 @@ import (
|
||||||
"github.com/docker/docker/api/types/container"
|
"github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/pkg/stdcopy"
|
"github.com/docker/docker/pkg/stdcopy"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// FIXME: agent logs not updating.
|
||||||
func Logs(w http.ResponseWriter, r *http.Request) {
|
func Logs(w http.ResponseWriter, r *http.Request) {
|
||||||
query := r.URL.Query()
|
query := r.URL.Query()
|
||||||
server := r.PathValue("server")
|
server := r.PathValue("server")
|
||||||
|
@ -68,7 +69,7 @@ func Logs(w http.ResponseWriter, r *http.Request) {
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logging.Err(err).
|
log.Err(err).
|
||||||
Str("server", server).
|
Str("server", server).
|
||||||
Str("container", containerID).
|
Str("container", containerID).
|
||||||
Msg("failed to de-multiplex logs")
|
Msg("failed to de-multiplex logs")
|
||||||
|
|
|
@ -11,9 +11,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -108,11 +108,11 @@ func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
||||||
RefreshToken: token,
|
RefreshToken: token,
|
||||||
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||||
})
|
})
|
||||||
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
|
log.Debug().Str("username", username).Msg("stored oauth refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
||||||
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
log.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
||||||
oauthRefreshTokens.Delete(string(sessionID))
|
oauthRefreshTokens.Delete(string(sessionID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||||
signed, err := jwtToken.SignedString(common.APIJWTSecret)
|
signed, err := jwtToken.SignedString(common.APIJWTSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Err(err).Msg("failed to sign session token")
|
log.Err(err).Msg("failed to sign session token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
|
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
|
||||||
|
@ -190,7 +190,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
|
||||||
return nil, refreshToken.err
|
return nil, refreshToken.err
|
||||||
}
|
}
|
||||||
|
|
||||||
idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
|
idTokenJWT, idToken, err := auth.getIDToken(ctx, newToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
|
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
|
||||||
return nil, refreshToken.err
|
return nil, refreshToken.err
|
||||||
|
@ -205,7 +205,7 @@ func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oaut
|
||||||
|
|
||||||
sessionID := newSessionID()
|
sessionID := newSessionID()
|
||||||
|
|
||||||
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
log.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||||
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
||||||
|
|
||||||
refreshToken.result = &RefreshResult{
|
refreshToken.result = &RefreshResult{
|
||||||
|
|
|
@ -12,9 +12,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -38,8 +38,8 @@ type (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CookieOauthState = "godoxy_oidc_state"
|
CookieOauthState = "godoxy_oidc_state"
|
||||||
CookieOauthToken = "godoxy_oauth_token"
|
CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
|
||||||
CookieOauthSessionToken = "godoxy_session_token"
|
CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -79,7 +79,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all
|
||||||
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
||||||
if err != nil && provider.EndSessionEndpoint() != "" {
|
if err != nil && provider.EndSessionEndpoint() != "" {
|
||||||
// non critical, just warn
|
// non critical, just warn
|
||||||
logging.Warn().
|
log.Warn().
|
||||||
Str("issuer", issuerURL).
|
Str("issuer", issuerURL).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("failed to parse end session URL")
|
Msg("failed to parse end session URL")
|
||||||
|
@ -129,7 +129,7 @@ func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
||||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
|
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
func (auth *OIDCProvider) getIDToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||||
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", nil, errMissingIDToken
|
return "", nil, errMissingIDToken
|
||||||
|
@ -176,7 +176,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// clear cookies then redirect to home
|
// clear cookies then redirect to home
|
||||||
logging.Err(err).Msg("failed to refresh token")
|
log.Err(err).Msg("failed to refresh token")
|
||||||
auth.clearCookie(w, r)
|
auth.clearCookie(w, r)
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
return
|
return
|
||||||
|
@ -201,11 +201,12 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
||||||
|
|
||||||
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
||||||
userAllowed := slices.Contains(auth.allowedUsers, user)
|
userAllowed := slices.Contains(auth.allowedUsers, user)
|
||||||
if !userAllowed {
|
if userAllowed {
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
if len(auth.allowedGroups) == 0 {
|
if len(auth.allowedGroups) == 0 {
|
||||||
return true
|
// user is not allowed, but no groups are allowed
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
|
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
|
||||||
}
|
}
|
||||||
|
@ -257,7 +258,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token)
|
idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gphttp.ServerError(w, r, err)
|
gphttp.ServerError(w, r, err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -5,13 +5,16 @@ import (
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certcrypto"
|
"github.com/go-acme/lego/v4/certcrypto"
|
||||||
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,13 +25,19 @@ type Config struct {
|
||||||
KeyPath string `json:"key_path,omitempty"`
|
KeyPath string `json:"key_path,omitempty"`
|
||||||
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
||||||
Provider string `json:"provider,omitempty"`
|
Provider string `json:"provider,omitempty"`
|
||||||
|
CADirURL string `json:"ca_dir_url,omitempty"`
|
||||||
Options map[string]any `json:"options,omitempty"`
|
Options map[string]any `json:"options,omitempty"`
|
||||||
|
|
||||||
|
HTTPClient *http.Client `json:"-"` // for tests only
|
||||||
|
|
||||||
|
challengeProvider challenge.Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingDomain = gperr.New("missing field 'domains'")
|
ErrMissingDomain = gperr.New("missing field 'domains'")
|
||||||
ErrMissingEmail = gperr.New("missing field 'email'")
|
ErrMissingEmail = gperr.New("missing field 'email'")
|
||||||
ErrMissingProvider = gperr.New("missing field 'provider'")
|
ErrMissingProvider = gperr.New("missing field 'provider'")
|
||||||
|
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
|
||||||
ErrInvalidDomain = gperr.New("invalid domain")
|
ErrInvalidDomain = gperr.New("invalid domain")
|
||||||
ErrUnknownProvider = gperr.New("unknown provider")
|
ErrUnknownProvider = gperr.New("unknown provider")
|
||||||
)
|
)
|
||||||
|
@ -36,6 +45,7 @@ var (
|
||||||
const (
|
const (
|
||||||
ProviderLocal = "local"
|
ProviderLocal = "local"
|
||||||
ProviderPseudo = "pseudo"
|
ProviderPseudo = "pseudo"
|
||||||
|
ProviderCustom = "custom"
|
||||||
)
|
)
|
||||||
|
|
||||||
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
||||||
|
@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
b := gperr.NewBuilder("autocert errors")
|
b := gperr.NewBuilder("autocert errors")
|
||||||
|
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
|
||||||
|
b.Add(ErrMissingCADirURL)
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||||
if len(cfg.Domains) == 0 {
|
if len(cfg.Domains) == 0 {
|
||||||
b.Add(ErrMissingDomain)
|
b.Add(ErrMissingDomain)
|
||||||
|
@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error {
|
||||||
if cfg.Email == "" {
|
if cfg.Email == "" {
|
||||||
b.Add(ErrMissingEmail)
|
b.Add(ErrMissingEmail)
|
||||||
}
|
}
|
||||||
for i, d := range cfg.Domains {
|
if cfg.Provider != ProviderCustom {
|
||||||
if !domainOrWildcardRE.MatchString(d) {
|
for i, d := range cfg.Domains {
|
||||||
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
if !domainOrWildcardRE.MatchString(d) {
|
||||||
|
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// check if provider is implemented
|
// check if provider is implemented
|
||||||
providerConstructor, ok := Providers[cfg.Provider]
|
providerConstructor, ok := Providers[cfg.Provider]
|
||||||
if !ok {
|
if !ok {
|
||||||
b.Add(ErrUnknownProvider.
|
if cfg.Provider != ProviderCustom {
|
||||||
Subject(cfg.Provider).
|
b.Add(ErrUnknownProvider.
|
||||||
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
|
Subject(cfg.Provider).
|
||||||
|
With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers))))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err := providerConstructor(cfg.Options)
|
provider, err := providerConstructor(cfg.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Add(err)
|
b.Add(err)
|
||||||
|
} else {
|
||||||
|
cfg.challengeProvider = provider
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.challengeProvider == nil {
|
||||||
|
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
|
||||||
|
}
|
||||||
return b.Error()
|
return b.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,8 +124,7 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
||||||
|
|
||||||
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||||
if privKey, err = cfg.LoadACMEKey(); err != nil {
|
if privKey, err = cfg.LoadACMEKey(); err != nil {
|
||||||
logging.Info().Err(err).Msg("load ACME private key failed")
|
log.Info().Err(err).Msg("failed to load ACME private key, generating a now one")
|
||||||
logging.Info().Msg("generate new ACME private key")
|
|
||||||
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, gperr.New("generate ACME private key").With(err)
|
return nil, nil, gperr.New("generate ACME private key").With(err)
|
||||||
|
@ -118,12 +141,23 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
legoCfg := lego.NewConfig(user)
|
legoCfg := lego.NewConfig(user)
|
||||||
legoCfg.Certificate.KeyType = certcrypto.RSA2048
|
legoCfg.Certificate.KeyType = certcrypto.EC256
|
||||||
|
|
||||||
|
if cfg.HTTPClient != nil {
|
||||||
|
legoCfg.HTTPClient = cfg.HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.CADirURL != "" {
|
||||||
|
legoCfg.CADirURL = cfg.CADirURL
|
||||||
|
}
|
||||||
|
|
||||||
return user, legoCfg, nil
|
return user, legoCfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||||
|
if common.IsTest {
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
data, err := os.ReadFile(cfg.ACMEKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -132,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error {
|
func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error {
|
||||||
|
if common.IsTest {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
data, err := x509.MarshalECPrivateKey(key)
|
data, err := x509.MarshalECPrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -5,18 +5,20 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"slices"
|
||||||
"sort"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certificate"
|
"github.com/go-acme/lego/v4/certificate"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
"github.com/go-acme/lego/v4/registration"
|
"github.com/go-acme/lego/v4/registration"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/notif"
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
@ -76,13 +78,11 @@ func (p *Provider) ObtainCert() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.cfg.Provider == ProviderPseudo {
|
if p.cfg.Provider == ProviderPseudo {
|
||||||
t := time.NewTicker(1000 * time.Millisecond)
|
log.Info().Msg("init client for pseudo provider")
|
||||||
defer t.Stop()
|
<-time.After(time.Second)
|
||||||
logging.Info().Msg("init client for pseudo provider")
|
log.Info().Msg("registering acme for pseudo provider")
|
||||||
<-t.C
|
<-time.After(time.Second)
|
||||||
logging.Info().Msg("registering acme for pseudo provider")
|
log.Info().Msg("obtained cert for pseudo provider")
|
||||||
<-t.C
|
|
||||||
logging.Info().Msg("obtained cert for pseudo provider")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ func (p *Provider) ObtainCert() error {
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.legoCert = nil
|
p.legoCert = nil
|
||||||
logging.Err(err).Msg("cert renew failed, fallback to obtain")
|
log.Err(err).Msg("cert renew failed, fallback to obtain")
|
||||||
} else {
|
} else {
|
||||||
p.legoCert = cert
|
p.legoCert = cert
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ func (p *Provider) LoadCert() error {
|
||||||
p.tlsCert = &cert
|
p.tlsCert = &cert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
|
||||||
logging.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
log.Info().Msgf("next renewal in %v", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||||
return p.renewIfNeeded()
|
return p.renewIfNeeded()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,13 +219,7 @@ func (p *Provider) initClient() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
generator := Providers[p.cfg.Provider]
|
err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider)
|
||||||
legoProvider, pErr := generator(p.cfg.Options)
|
|
||||||
if pErr != nil {
|
|
||||||
return pErr
|
|
||||||
}
|
|
||||||
|
|
||||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -240,7 +234,7 @@ func (p *Provider) registerACME() error {
|
||||||
}
|
}
|
||||||
if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil {
|
if reg, err := p.client.Registration.ResolveAccountByKey(); err == nil {
|
||||||
p.user.Registration = reg
|
p.user.Registration = reg
|
||||||
logging.Info().Msg("reused acme registration from private key")
|
log.Info().Msg("reused acme registration from private key")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,11 +243,14 @@ func (p *Provider) registerACME() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.user.Registration = reg
|
p.user.Registration = reg
|
||||||
logging.Info().Interface("reg", reg).Msg("acme registered")
|
log.Info().Interface("reg", reg).Msg("acme registered")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) saveCert(cert *certificate.Resource) error {
|
func (p *Provider) saveCert(cert *certificate.Resource) error {
|
||||||
|
if common.IsTest {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
/* This should have been done in setup
|
/* This should have been done in setup
|
||||||
but double check is always a good choice.*/
|
but double check is always a good choice.*/
|
||||||
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
_, err := os.Stat(path.Dir(p.cfg.CertPath))
|
||||||
|
@ -283,22 +280,19 @@ func (p *Provider) certState() CertState {
|
||||||
return CertStateExpired
|
return CertStateExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
certDomains := make([]string, len(p.certExpiries))
|
if len(p.certExpiries) != len(p.cfg.Domains) {
|
||||||
wantedDomains := make([]string, len(p.cfg.Domains))
|
|
||||||
i := 0
|
|
||||||
for domain := range p.certExpiries {
|
|
||||||
certDomains[i] = domain
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
copy(wantedDomains, p.cfg.Domains)
|
|
||||||
sort.Strings(wantedDomains)
|
|
||||||
sort.Strings(certDomains)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(certDomains, wantedDomains) {
|
|
||||||
logging.Info().Msgf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
|
|
||||||
return CertStateMismatch
|
return CertStateMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i := range len(p.cfg.Domains) {
|
||||||
|
if _, ok := p.certExpiries[p.cfg.Domains[i]]; !ok {
|
||||||
|
log.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s",
|
||||||
|
strings.Join(slices.Collect(maps.Keys(p.certExpiries)), ", "),
|
||||||
|
strings.Join(p.cfg.Domains, ", "))
|
||||||
|
return CertStateMismatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return CertStateValid
|
return CertStateValid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,9 +303,9 @@ func (p *Provider) renewIfNeeded() error {
|
||||||
|
|
||||||
switch p.certState() {
|
switch p.certState() {
|
||||||
case CertStateExpired:
|
case CertStateExpired:
|
||||||
logging.Info().Msg("certs expired, renewing")
|
log.Info().Msg("certs expired, renewing")
|
||||||
case CertStateMismatch:
|
case CertStateMismatch:
|
||||||
logging.Info().Msg("cert domains mismatch with config, renewing")
|
log.Info().Msg("cert domains mismatch with config, renewing")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
453
internal/autocert/provider_test/custom_test.go
Normal file
453
internal/autocert/provider_test/custom_test.go
Normal file
|
@ -0,0 +1,453 @@
|
||||||
|
//nolint:errchkjson,errcheck
|
||||||
|
package provider_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/go-proxy/internal/autocert"
|
||||||
|
"github.com/yusing/go-proxy/internal/dnsproviders"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
dnsproviders.InitProviders()
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomProvider(t *testing.T) {
|
||||||
|
t.Run("valid custom provider with step-ca", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Domains: []string{"example.com", "*.example.com"},
|
||||||
|
Provider: autocert.ProviderCustom,
|
||||||
|
CADirURL: "https://ca.example.com:9000/acme/acme/directory",
|
||||||
|
CertPath: "certs/custom.crt",
|
||||||
|
KeyPath: "certs/custom.key",
|
||||||
|
ACMEKeyPath: "certs/custom-acme.key",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.Validate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user, legoCfg, err := cfg.GetLegoConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.NotNil(t, legoCfg)
|
||||||
|
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL)
|
||||||
|
require.Equal(t, "test@example.com", user.Email)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom provider missing CADirURL", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
Provider: autocert.ProviderCustom,
|
||||||
|
// CADirURL is missing
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Email: "admin@internal.com",
|
||||||
|
Domains: []string{"internal.example.com", "api.internal.example.com"},
|
||||||
|
Provider: autocert.ProviderCustom,
|
||||||
|
CADirURL: "https://step-ca.internal:443/acme/acme/directory",
|
||||||
|
CertPath: "certs/internal.crt",
|
||||||
|
KeyPath: "certs/internal.key",
|
||||||
|
ACMEKeyPath: "certs/internal-acme.key",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.Validate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user, legoCfg, err := cfg.GetLegoConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.NotNil(t, legoCfg)
|
||||||
|
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
|
||||||
|
require.Equal(t, "admin@internal.com", user.Email)
|
||||||
|
|
||||||
|
provider := autocert.NewProvider(cfg, user, legoCfg)
|
||||||
|
require.NotNil(t, provider)
|
||||||
|
require.Equal(t, autocert.ProviderCustom, provider.GetName())
|
||||||
|
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
|
||||||
|
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestObtainCertFromCustomProvider(t *testing.T) {
|
||||||
|
// Create a test ACME server
|
||||||
|
acmeServer := newTestACMEServer(t)
|
||||||
|
defer acmeServer.Close()
|
||||||
|
|
||||||
|
t.Run("obtain cert from custom step-ca server", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Email: "test@example.com",
|
||||||
|
Domains: []string{"test.example.com"},
|
||||||
|
Provider: autocert.ProviderCustom,
|
||||||
|
CADirURL: acmeServer.URL() + "/acme/acme/directory",
|
||||||
|
CertPath: "certs/stepca-test.crt",
|
||||||
|
KeyPath: "certs/stepca-test.key",
|
||||||
|
ACMEKeyPath: "certs/stepca-test-acme.key",
|
||||||
|
HTTPClient: acmeServer.httpClient(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := error(cfg.Validate())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user, legoCfg, err := cfg.GetLegoConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
require.NotNil(t, legoCfg)
|
||||||
|
|
||||||
|
provider := autocert.NewProvider(cfg, user, legoCfg)
|
||||||
|
require.NotNil(t, provider)
|
||||||
|
|
||||||
|
// Test obtaining certificate
|
||||||
|
err = provider.ObtainCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify certificate was obtained
|
||||||
|
cert, err := provider.GetCert(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, cert)
|
||||||
|
|
||||||
|
// Verify certificate properties
|
||||||
|
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, x509Cert.DNSNames, "test.example.com")
|
||||||
|
require.True(t, time.Now().Before(x509Cert.NotAfter))
|
||||||
|
require.True(t, time.Now().After(x509Cert.NotBefore))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testACMEServer implements a minimal ACME server for testing.
|
||||||
|
type testACMEServer struct {
|
||||||
|
server *httptest.Server
|
||||||
|
caCert *x509.Certificate
|
||||||
|
caKey *rsa.PrivateKey
|
||||||
|
clientCSRs map[string]*x509.CertificateRequest
|
||||||
|
orderID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestACMEServer(t *testing.T) *testACMEServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate CA certificate and key
|
||||||
|
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
caTemplate := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Test CA"},
|
||||||
|
Country: []string{"US"},
|
||||||
|
Province: []string{""},
|
||||||
|
Locality: []string{"Test"},
|
||||||
|
StreetAddress: []string{""},
|
||||||
|
PostalCode: []string{""},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
caCert, err := x509.ParseCertificate(caCertDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
acme := &testACMEServer{
|
||||||
|
caCert: caCert,
|
||||||
|
caKey: caKey,
|
||||||
|
clientCSRs: make(map[string]*x509.CertificateRequest),
|
||||||
|
orderID: "test-order-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
acme.setupRoutes(mux)
|
||||||
|
|
||||||
|
acme.server = httptest.NewTLSServer(mux)
|
||||||
|
return acme
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) Close() {
|
||||||
|
s.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) URL() string {
|
||||||
|
return s.server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) httpClient() *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: 30 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 30 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true, //nolint:gosec
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
|
||||||
|
// ACME directory endpoint
|
||||||
|
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
|
||||||
|
|
||||||
|
// ACME endpoints
|
||||||
|
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
|
||||||
|
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
|
||||||
|
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
|
||||||
|
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
|
||||||
|
mux.HandleFunc("/acme/chall/", s.handleChallenge)
|
||||||
|
mux.HandleFunc("/acme/order/", s.handleOrder)
|
||||||
|
mux.HandleFunc("/acme/cert/", s.handleCertificate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
directory := map[string]interface{}{
|
||||||
|
"newNonce": s.server.URL + "/acme/new-nonce",
|
||||||
|
"newAccount": s.server.URL + "/acme/new-account",
|
||||||
|
"newOrder": s.server.URL + "/acme/new-order",
|
||||||
|
"keyChange": s.server.URL + "/acme/key-change",
|
||||||
|
"meta": map[string]interface{}{
|
||||||
|
"termsOfService": s.server.URL + "/terms",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(directory)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-12345")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account := map[string]interface{}{
|
||||||
|
"status": "valid",
|
||||||
|
"contact": []string{"mailto:test@example.com"},
|
||||||
|
"orders": s.server.URL + "/acme/orders",
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", s.server.URL+"/acme/account/1")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-67890")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authzID := "test-authz-456"
|
||||||
|
|
||||||
|
order := map[string]interface{}{
|
||||||
|
"status": "ready", // Skip pending state for simplicity
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||||
|
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
|
||||||
|
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-order")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authz := map[string]interface{}{
|
||||||
|
"status": "valid", // Skip challenge validation for simplicity
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
|
||||||
|
"challenges": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"type": "dns-01",
|
||||||
|
"status": "valid",
|
||||||
|
"url": s.server.URL + "/acme/chall/test-chall-789",
|
||||||
|
"token": "test-token-abc123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-authz")
|
||||||
|
json.NewEncoder(w).Encode(authz)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
|
challenge := map[string]interface{}{
|
||||||
|
"type": "dns-01",
|
||||||
|
"status": "valid",
|
||||||
|
"url": r.URL.String(),
|
||||||
|
"token": "test-token-abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-chall")
|
||||||
|
json.NewEncoder(w).Encode(challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/finalize") {
|
||||||
|
s.handleFinalize(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
||||||
|
order := map[string]interface{}{
|
||||||
|
"status": "valid",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||||
|
"certificate": certURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Read the JWS payload
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to read request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract CSR from JWS payload
|
||||||
|
csr, err := s.extractCSRFromJWS(body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the CSR for certificate generation
|
||||||
|
s.clientCSRs[s.orderID] = csr
|
||||||
|
|
||||||
|
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
||||||
|
order := map[string]interface{}{
|
||||||
|
"status": "valid",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
||||||
|
"certificate": certURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
|
||||||
|
// Parse the JWS structure
|
||||||
|
var jws struct {
|
||||||
|
Protected string `json:"protected"`
|
||||||
|
Payload string `json:"payload"`
|
||||||
|
Signature string `json:"signature"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jwsData, &jws); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the payload
|
||||||
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the finalize request
|
||||||
|
var finalizeReq struct {
|
||||||
|
CSR string `json:"csr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the CSR
|
||||||
|
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the CSR
|
||||||
|
csr, err := x509.ParseCertificateRequest(csrBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return csr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Extract order ID from URL
|
||||||
|
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
|
||||||
|
|
||||||
|
// Get the CSR for this order
|
||||||
|
csr, exists := s.clientCSRs[orderID]
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, "No CSR found for order", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate using the public key from the client's CSR
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Test Cert"},
|
||||||
|
Country: []string{"US"},
|
||||||
|
},
|
||||||
|
DNSNames: csr.DNSNames,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(90 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the public key from the CSR and sign with CA key
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return certificate chain
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-cert")
|
||||||
|
w.Write(append(certPEM, caPEM...))
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"github.com/go-acme/lego/v4/providers/dns/ovh"
|
"github.com/go-acme/lego/v4/providers/dns/ovh"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
// type Config struct {
|
// type Config struct {
|
||||||
|
@ -45,6 +45,6 @@ oauth2_config:
|
||||||
testYaml = testYaml[1:] // remove first \n
|
testYaml = testYaml[1:] // remove first \n
|
||||||
opt := make(map[string]any)
|
opt := make(map[string]any)
|
||||||
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
|
require.NoError(t, yaml.Unmarshal([]byte(testYaml), &opt))
|
||||||
require.NoError(t, utils.MapUnmarshalValidate(opt, cfg))
|
require.NoError(t, serialization.MapUnmarshalValidate(opt, cfg))
|
||||||
require.Equal(t, cfgExpected, cfg)
|
require.Equal(t, cfgExpected, cfg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ package autocert
|
||||||
import (
|
import (
|
||||||
"github.com/go-acme/lego/v4/challenge"
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Generator func(map[string]any) (challenge.Provider, gperr.Error)
|
type Generator func(map[string]any) (challenge.Provider, gperr.Error)
|
||||||
|
@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider](
|
||||||
) Generator {
|
) Generator {
|
||||||
return func(opt map[string]any) (challenge.Provider, gperr.Error) {
|
return func(opt map[string]any) (challenge.Provider, gperr.Error) {
|
||||||
cfg := defaultCfg()
|
cfg := defaultCfg()
|
||||||
err := utils.MapUnmarshalValidate(opt, &cfg)
|
if len(opt) > 0 {
|
||||||
if err != nil {
|
err := serialization.MapUnmarshalValidate(opt, &cfg)
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
p, pErr := newProvider(cfg)
|
p, pErr := newProvider(cfg)
|
||||||
return p, gperr.Wrap(pErr)
|
return p, gperr.Wrap(pErr)
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,14 +13,14 @@ func (p *Provider) Setup() (err error) {
|
||||||
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
|
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logging.Debug().Msg("obtaining cert due to error loading cert")
|
log.Debug().Msg("obtaining cert due to error loading cert")
|
||||||
if err = p.ObtainCert(); err != nil {
|
if err = p.ObtainCert(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, expiry := range p.GetExpiries() {
|
for _, expiry := range p.GetExpiries() {
|
||||||
logging.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
|
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,20 +10,20 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/api"
|
"github.com/yusing/go-proxy/internal/api"
|
||||||
autocert "github.com/yusing/go-proxy/internal/autocert"
|
autocert "github.com/yusing/go-proxy/internal/autocert"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/entrypoint"
|
"github.com/yusing/go-proxy/internal/entrypoint"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/maxmind"
|
"github.com/yusing/go-proxy/internal/maxmind"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
"github.com/yusing/go-proxy/internal/net/gphttp/server"
|
||||||
"github.com/yusing/go-proxy/internal/notif"
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/proxmox"
|
"github.com/yusing/go-proxy/internal/proxmox"
|
||||||
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
proxy "github.com/yusing/go-proxy/internal/route/provider"
|
||||||
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
|
||||||
"github.com/yusing/go-proxy/internal/watcher"
|
"github.com/yusing/go-proxy/internal/watcher"
|
||||||
|
@ -96,10 +96,10 @@ func OnConfigChange(ev []events.Event) {
|
||||||
// just reload once and check the last event
|
// just reload once and check the last event
|
||||||
switch ev[len(ev)-1].Action {
|
switch ev[len(ev)-1].Action {
|
||||||
case events.ActionFileRenamed:
|
case events.ActionFileRenamed:
|
||||||
logging.Warn().Msg(cfgRenameWarn)
|
log.Warn().Msg(cfgRenameWarn)
|
||||||
return
|
return
|
||||||
case events.ActionFileDeleted:
|
case events.ActionFileDeleted:
|
||||||
logging.Warn().Msg(cfgDeleteWarn)
|
log.Warn().Msg(cfgDeleteWarn)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +161,7 @@ func (cfg *Config) Start(opts ...*StartServersOptions) {
|
||||||
func (cfg *Config) StartAutoCert() {
|
func (cfg *Config) StartAutoCert() {
|
||||||
autocert := cfg.autocertProvider
|
autocert := cfg.autocertProvider
|
||||||
if autocert == nil {
|
if autocert == nil {
|
||||||
logging.Info().Msg("autocert not configured")
|
log.Info().Msg("autocert not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ func (cfg *Config) load() gperr.Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
model := config.DefaultConfig()
|
model := config.DefaultConfig()
|
||||||
if err := utils.UnmarshalValidateYAML(data, model); err != nil {
|
if err := serialization.UnmarshalValidateYAML(data, model); err != nil {
|
||||||
gperr.LogFatal(errMsg, err)
|
gperr.LogFatal(errMsg, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -374,6 +374,6 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
|
||||||
}
|
}
|
||||||
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
|
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
|
||||||
})
|
})
|
||||||
logging.Info().Msg(results.String())
|
log.Info().Msg(results.String())
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||||
"github.com/yusing/go-proxy/internal/notif"
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/proxmox"
|
"github.com/yusing/go-proxy/internal/proxmox"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -93,14 +93,14 @@ func HasInstance() bool {
|
||||||
|
|
||||||
func Validate(data []byte) gperr.Error {
|
func Validate(data []byte) gperr.Error {
|
||||||
var model Config
|
var model Config
|
||||||
return utils.UnmarshalValidateYAML(data, &model)
|
return serialization.UnmarshalValidateYAML(data, &model)
|
||||||
}
|
}
|
||||||
|
|
||||||
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
|
var matchDomainsRegex = regexp.MustCompile(`^[^\.]?([\w\d\-_]\.?)+[^\.]?$`)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
utils.RegisterDefaultValueFactory(DefaultConfig)
|
serialization.RegisterDefaultValueFactory(DefaultConfig)
|
||||||
utils.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
serialization.MustRegisterValidation("domain_name", func(fl validator.FieldLevel) bool {
|
||||||
domains := fl.Field().Interface().([]string)
|
domains := fl.Field().Interface().([]string)
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if !matchDomainsRegex.MatchString(domain) {
|
if !matchDomainsRegex.MatchString(domain) {
|
||||||
|
@ -109,7 +109,7 @@ func init() {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
|
serialization.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool {
|
||||||
m := fl.Field().Interface().(map[string]string)
|
m := fl.Field().Interface().(map[string]string)
|
||||||
for k := range m {
|
for k := range m {
|
||||||
if k == "" {
|
if k == "" {
|
||||||
|
|
|
@ -4,6 +4,8 @@ go 1.24.3
|
||||||
|
|
||||||
replace github.com/yusing/go-proxy => ../..
|
replace github.com/yusing/go-proxy => ../..
|
||||||
|
|
||||||
|
replace github.com/yusing/go-proxy/internal/utils => ../utils
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-acme/lego/v4 v4.23.1
|
github.com/go-acme/lego/v4 v4.23.1
|
||||||
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
|
||||||
|
@ -156,6 +158,7 @@ require (
|
||||||
github.com/vultr/govultr/v3 v3.20.0 // indirect
|
github.com/vultr/govultr/v3 v3.20.0 // indirect
|
||||||
github.com/x448/float16 v0.8.4 // indirect
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||||
|
github.com/yusing/go-proxy/internal/utils v0.0.0 // indirect
|
||||||
go.mongodb.org/mongo-driver v1.17.3 // indirect
|
go.mongodb.org/mongo-driver v1.17.3 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
|
||||||
|
|
|
@ -13,13 +13,14 @@ import (
|
||||||
|
|
||||||
"github.com/docker/cli/cli/connhelper"
|
"github.com/docker/cli/cli/connhelper"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: implement reconnect here.
|
||||||
type (
|
type (
|
||||||
SharedClient struct {
|
SharedClient struct {
|
||||||
*client.Client
|
*client.Client
|
||||||
|
@ -83,7 +84,7 @@ func closeTimedOutClients() {
|
||||||
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
|
if atomic.LoadUint32(&c.refCount) == 0 && now-atomic.LoadInt64(&c.closedOn) > clientTTLSecs {
|
||||||
delete(clientMap, c.Key())
|
delete(clientMap, c.Key())
|
||||||
c.Client.Close()
|
c.Client.Close()
|
||||||
logging.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
|
log.Debug().Str("host", c.DaemonHost()).Msg("docker client closed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,7 +149,7 @@ func NewClient(host string) (*SharedClient, error) {
|
||||||
default:
|
default:
|
||||||
helper, err := connhelper.GetConnectionHelper(host)
|
helper, err := connhelper.GetConnectionHelper(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Panic().Err(err).Msg("failed to get connection helper")
|
log.Panic().Err(err).Msg("failed to get connection helper")
|
||||||
}
|
}
|
||||||
if helper != nil {
|
if helper != nil {
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
|
@ -189,10 +190,10 @@ func NewClient(host string) (*SharedClient, error) {
|
||||||
c.dial = client.Dialer()
|
c.dial = client.Dialer()
|
||||||
}
|
}
|
||||||
if c.addr == "" {
|
if c.addr == "" {
|
||||||
c.addr = c.Client.DaemonHost()
|
c.addr = c.DaemonHost()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer logging.Debug().Str("host", host).Msg("docker client initialized")
|
defer log.Debug().Str("host", host).Msg("docker client initialized")
|
||||||
|
|
||||||
clientMap[c.Key()] = c
|
clientMap[c.Key()] = c
|
||||||
return c, nil
|
return c, nil
|
||||||
|
|
|
@ -9,11 +9,12 @@ import (
|
||||||
|
|
||||||
"github.com/docker/docker/api/types/container"
|
"github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/go-connections/nat"
|
"github.com/docker/go-connections/nat"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -90,7 +91,7 @@ func FromDocker(c *container.SummaryTrimmed, dockerHost string) (res *Container)
|
||||||
var ok bool
|
var ok bool
|
||||||
res.Agent, ok = config.GetInstance().GetAgent(dockerHost)
|
res.Agent, ok = config.GetInstance().GetAgent(dockerHost)
|
||||||
if !ok {
|
if !ok {
|
||||||
logging.Error().Msgf("agent %q not found", dockerHost)
|
log.Error().Msgf("agent %q not found", dockerHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +184,7 @@ func (c *Container) setPublicHostname() {
|
||||||
}
|
}
|
||||||
url, err := url.Parse(c.DockerHost)
|
url, err := url.Parse(c.DockerHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
log.Err(err).Msgf("invalid docker host %q, falling back to 127.0.0.1", c.DockerHost)
|
||||||
c.PublicHostname = "127.0.0.1"
|
c.PublicHostname = "127.0.0.1"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -224,7 +225,7 @@ func (c *Container) loadDeleteIdlewatcherLabels(helper containerHelper) {
|
||||||
ContainerName: c.ContainerName,
|
ContainerName: c.ContainerName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := utils.MapUnmarshalValidate(cfg, idwCfg)
|
err := serialization.MapUnmarshalValidate(cfg, idwCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
|
gperr.LogWarn("invalid idlewatcher config", gperr.PrependSubject(c.ContainerName, err))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
"github.com/yusing/go-proxy/internal/net/gphttp/middleware"
|
||||||
|
@ -50,7 +50,7 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
||||||
}
|
}
|
||||||
ep.middleware = mid
|
ep.middleware = mid
|
||||||
|
|
||||||
logging.Debug().Msg("entrypoint middleware loaded")
|
log.Debug().Msg("entrypoint middleware loaded")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logging.Debug().Msg("entrypoint access logger created")
|
log.Debug().Msg("entrypoint access logger created")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Then scraper / scanners will know the subdomain is invalid.
|
// Then scraper / scanners will know the subdomain is invalid.
|
||||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||||
logging.Err(err).
|
log.Err(err).
|
||||||
Str("method", r.Method).
|
Str("method", r.Method).
|
||||||
Str("url", r.URL.String()).
|
Str("url", r.URL.String()).
|
||||||
Str("remote", r.RemoteAddr).
|
Str("remote", r.RemoteAddr).
|
||||||
|
@ -99,7 +99,7 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
if _, err := w.Write(errorPage); err != nil {
|
if _, err := w.Write(errorPage); err != nil {
|
||||||
logging.Err(err).Msg("failed to write error page")
|
log.Err(err).Msg("failed to write error page")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
http.Error(w, err.Error(), http.StatusNotFound)
|
http.Error(w, err.Error(), http.StatusNotFound)
|
||||||
|
|
|
@ -4,8 +4,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
zerologlog "github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
|
func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger) {
|
||||||
|
@ -13,7 +13,7 @@ func log(msg string, err error, level zerolog.Level, logger ...*zerolog.Logger)
|
||||||
if len(logger) > 0 {
|
if len(logger) > 0 {
|
||||||
l = logger[0]
|
l = logger[0]
|
||||||
} else {
|
} else {
|
||||||
l = logging.GetLogger()
|
l = &zerologlog.Logger
|
||||||
}
|
}
|
||||||
l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error())
|
l.WithLevel(level).Msg(New(highlightANSI(msg)).With(err).Error())
|
||||||
switch level {
|
switch level {
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/homepage/widgets"
|
"github.com/yusing/go-proxy/internal/homepage/widgets"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -32,7 +32,7 @@ type (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
utils.RegisterDefaultValueFactory(func() *ItemConfig {
|
serialization.RegisterDefaultValueFactory(func() *ItemConfig {
|
||||||
return &ItemConfig{
|
return &ItemConfig{
|
||||||
Show: true,
|
Show: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||||
|
@ -74,7 +74,7 @@ func pruneExpiredIconCache() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if nPruned > 0 {
|
if nPruned > 0 {
|
||||||
logging.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
|
log.Info().Int("pruned", nPruned).Msg("pruned expired icon cache")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func loadIconCache(key string) *FetchResult {
|
||||||
defer iconMu.RUnlock()
|
defer iconMu.RUnlock()
|
||||||
icon, ok := iconCache.Load(key)
|
icon, ok := iconCache.Load(key)
|
||||||
if ok && len(icon.Icon) > 0 {
|
if ok && len(icon.Icon) > 0 {
|
||||||
logging.Debug().
|
log.Debug().
|
||||||
Str("key", key).
|
Str("key", key).
|
||||||
Msg("icon found in cache")
|
Msg("icon found in cache")
|
||||||
icon.LastAccess.Store(utils.TimeNow())
|
icon.LastAccess.Store(utils.TimeNow())
|
||||||
|
@ -99,7 +99,7 @@ func loadIconCache(key string) *FetchResult {
|
||||||
func storeIconCache(key string, result *FetchResult) {
|
func storeIconCache(key string, result *FetchResult) {
|
||||||
icon := result.Icon
|
icon := result.Icon
|
||||||
if len(icon) > maxIconSize {
|
if len(icon) > maxIconSize {
|
||||||
logging.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
|
log.Debug().Int("size", len(icon)).Msg("icon cache size exceeds max cache size")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ func storeIconCache(key string, result *FetchResult) {
|
||||||
entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
|
entry := &cacheEntry{Icon: icon, ContentType: result.contentType}
|
||||||
entry.LastAccess.Store(time.Now())
|
entry.LastAccess.Store(time.Now())
|
||||||
iconCache.Store(key, entry)
|
iconCache.Store(key, entry)
|
||||||
logging.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
|
log.Debug().Str("key", key).Int("size", len(icon)).Msg("stored icon cache")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *cacheEntry) IsExpired() bool {
|
func (e *cacheEntry) IsExpired() bool {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package homepage
|
package homepage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -10,10 +11,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,30 +47,30 @@ type (
|
||||||
func (icon *IconMeta) Filenames(ref string) []string {
|
func (icon *IconMeta) Filenames(ref string) []string {
|
||||||
filenames := make([]string, 0)
|
filenames := make([]string, 0)
|
||||||
if icon.SVG {
|
if icon.SVG {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s.svg", ref))
|
filenames = append(filenames, ref+".svg")
|
||||||
if icon.Light {
|
if icon.Light {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-light.svg", ref))
|
filenames = append(filenames, ref+"-light.svg")
|
||||||
}
|
}
|
||||||
if icon.Dark {
|
if icon.Dark {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-dark.svg", ref))
|
filenames = append(filenames, ref+"-dark.svg")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if icon.PNG {
|
if icon.PNG {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s.png", ref))
|
filenames = append(filenames, ref+".png")
|
||||||
if icon.Light {
|
if icon.Light {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-light.png", ref))
|
filenames = append(filenames, ref+"-light.png")
|
||||||
}
|
}
|
||||||
if icon.Dark {
|
if icon.Dark {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-dark.png", ref))
|
filenames = append(filenames, ref+"-dark.png")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if icon.WebP {
|
if icon.WebP {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s.webp", ref))
|
filenames = append(filenames, ref+".webp")
|
||||||
if icon.Light {
|
if icon.Light {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-light.webp", ref))
|
filenames = append(filenames, ref+"-light.webp")
|
||||||
}
|
}
|
||||||
if icon.Dark {
|
if icon.Dark {
|
||||||
filenames = append(filenames, fmt.Sprintf("%s-dark.webp", ref))
|
filenames = append(filenames, ref+"-dark.webp")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return filenames
|
return filenames
|
||||||
|
@ -99,21 +100,21 @@ func InitIconListCache() {
|
||||||
iconsCache.Lock()
|
iconsCache.Lock()
|
||||||
defer iconsCache.Unlock()
|
defer iconsCache.Unlock()
|
||||||
|
|
||||||
err := utils.LoadJSONIfExist(common.IconListCachePath, iconsCache)
|
err := serialization.LoadJSONIfExist(common.IconListCachePath, iconsCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error().Err(err).Msg("failed to load icons")
|
log.Error().Err(err).Msg("failed to load icons")
|
||||||
} else if len(iconsCache.Icons) > 0 {
|
} else if len(iconsCache.Icons) > 0 {
|
||||||
logging.Info().
|
log.Info().
|
||||||
Int("icons", len(iconsCache.Icons)).
|
Int("icons", len(iconsCache.Icons)).
|
||||||
Msg("icons loaded")
|
Msg("icons loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = updateIcons(); err != nil {
|
if err = updateIcons(); err != nil {
|
||||||
logging.Error().Err(err).Msg("failed to update icons")
|
log.Error().Err(err).Msg("failed to update icons")
|
||||||
}
|
}
|
||||||
|
|
||||||
task.OnProgramExit("save_icons_cache", func() {
|
task.OnProgramExit("save_icons_cache", func() {
|
||||||
utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
_ = serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,17 +135,17 @@ func ListAvailableIcons() (*Cache, error) {
|
||||||
iconsCache.Lock()
|
iconsCache.Lock()
|
||||||
defer iconsCache.Unlock()
|
defer iconsCache.Unlock()
|
||||||
|
|
||||||
logging.Info().Msg("updating icon data")
|
log.Info().Msg("updating icon data")
|
||||||
if err := updateIcons(); err != nil {
|
if err := updateIcons(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logging.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
|
log.Info().Int("icons", len(iconsCache.Icons)).Msg("icons list updated")
|
||||||
|
|
||||||
iconsCache.LastUpdate = time.Now()
|
iconsCache.LastUpdate = time.Now()
|
||||||
|
|
||||||
err := utils.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
err := serialization.SaveJSON(common.IconListCachePath, iconsCache, 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Warn().Err(err).Msg("failed to save icons")
|
log.Warn().Err(err).Msg("failed to save icons")
|
||||||
}
|
}
|
||||||
return iconsCache, nil
|
return iconsCache, nil
|
||||||
}
|
}
|
||||||
|
@ -230,14 +231,17 @@ func updateIcons() error {
|
||||||
|
|
||||||
var httpGet = httpGetImpl
|
var httpGet = httpGetImpl
|
||||||
|
|
||||||
func MockHttpGet(body []byte) {
|
func MockHTTPGet(body []byte) {
|
||||||
httpGet = func(_ string) ([]byte, error) {
|
httpGet = func(_ string) ([]byte, error) {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpGetImpl(url string) ([]byte, error) {
|
func httpGetImpl(url string) ([]byte, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -347,7 +351,7 @@ func UpdateSelfhstIcons() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
data := make([]SelfhStIcon, 0)
|
data := make([]SelfhStIcon, 0)
|
||||||
err = json.Unmarshal(body, &data)
|
err = json.Unmarshal(body, &data) //nolint:musttag
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,6 +68,8 @@ type testCases struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
|
func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
for _, item := range test {
|
for _, item := range test {
|
||||||
icon, ok := iconsCache.Icons[item.Key]
|
icon, ok := iconsCache.Icons[item.Key]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -89,7 +91,7 @@ func runTests(t *testing.T, iconsCache *Cache, test []testCases) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListWalkxCodeIcons(t *testing.T) {
|
func TestListWalkxCodeIcons(t *testing.T) {
|
||||||
MockHttpGet([]byte(walkxcodeIcons))
|
MockHTTPGet([]byte(walkxcodeIcons))
|
||||||
if err := UpdateWalkxCodeIcons(); err != nil {
|
if err := UpdateWalkxCodeIcons(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -122,7 +124,7 @@ func TestListWalkxCodeIcons(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListSelfhstIcons(t *testing.T) {
|
func TestListSelfhstIcons(t *testing.T) {
|
||||||
MockHttpGet([]byte(selfhstIcons))
|
MockHTTPGet([]byte(selfhstIcons))
|
||||||
if err := UpdateSelfhstIcons(); err != nil {
|
if err := UpdateSelfhstIcons(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -33,17 +33,18 @@ var widgetProviders = map[string]struct{}{
|
||||||
var ErrInvalidProvider = gperr.New("invalid provider")
|
var ErrInvalidProvider = gperr.New("invalid provider")
|
||||||
|
|
||||||
func (cfg *Config) UnmarshalMap(m map[string]any) error {
|
func (cfg *Config) UnmarshalMap(m map[string]any) error {
|
||||||
cfg.Provider = m["provider"].(string)
|
var ok bool
|
||||||
|
cfg.Provider, ok = m["provider"].(string)
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidProvider.Withf("non string")
|
||||||
|
}
|
||||||
if _, ok := widgetProviders[cfg.Provider]; !ok {
|
if _, ok := widgetProviders[cfg.Provider]; !ok {
|
||||||
return ErrInvalidProvider.Subject(cfg.Provider)
|
return ErrInvalidProvider.Subject(cfg.Provider)
|
||||||
}
|
}
|
||||||
delete(m, "provider")
|
delete(m, "provider")
|
||||||
m, ok := m["config"].(map[string]any)
|
m, ok = m["config"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return gperr.New("invalid config")
|
return gperr.New("invalid config")
|
||||||
}
|
}
|
||||||
if err := utils.MapUnmarshalValidate(m, &cfg.Config); err != nil {
|
return serialization.MapUnmarshalValidate(m, &cfg.Config)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,10 +7,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
|
"github.com/yusing/go-proxy/internal/idlewatcher/provider"
|
||||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||||
net "github.com/yusing/go-proxy/internal/net/types"
|
net "github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/route/routes"
|
"github.com/yusing/go-proxy/internal/route/routes"
|
||||||
|
@ -73,13 +73,13 @@ var dummyHealthCheckConfig = &health.HealthCheckConfig{
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
causeReload = gperr.New("reloaded")
|
causeReload = gperr.New("reloaded") //nolint:errname
|
||||||
causeContainerDestroy = gperr.New("container destroyed")
|
causeContainerDestroy = gperr.New("container destroyed") //nolint:errname
|
||||||
)
|
)
|
||||||
|
|
||||||
const reqTimeout = 3 * time.Second
|
const reqTimeout = 3 * time.Second
|
||||||
|
|
||||||
// TODO: fix stream type
|
// TODO: fix stream type.
|
||||||
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
||||||
cfg := r.IdlewatcherConfig()
|
cfg := r.IdlewatcherConfig()
|
||||||
key := cfg.Key()
|
key := cfg.Key()
|
||||||
|
@ -120,7 +120,7 @@ func NewWatcher(parent task.Parent, r routes.Route) (*Watcher, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
w.provider = p
|
w.provider = p
|
||||||
w.l = logging.With().
|
w.l = log.With().
|
||||||
Str("provider", providerType).
|
Str("provider", providerType).
|
||||||
Str("container", cfg.ContainerName()).
|
Str("container", cfg.ContainerName()).
|
||||||
Logger()
|
Logger()
|
||||||
|
|
|
@ -2,18 +2,17 @@ package jsonstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"maps"
|
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type namespace string
|
type namespace string
|
||||||
|
@ -36,13 +35,15 @@ type store interface {
|
||||||
json.Unmarshaler
|
json.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
var stores = make(map[namespace]store)
|
var (
|
||||||
var storesPath = common.DataDir
|
stores = make(map[namespace]store)
|
||||||
|
storesPath = common.DataDir
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
task.OnProgramExit("save_stores", func() {
|
task.OnProgramExit("save_stores", func() {
|
||||||
if err := save(); err != nil {
|
if err := save(); err != nil {
|
||||||
logging.Error().Err(err).Msg("failed to save stores")
|
log.Error().Err(err).Msg("failed to save stores")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -54,20 +55,20 @@ func loadNS[T store](ns namespace) T {
|
||||||
file, err := os.Open(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
logging.Err(err).
|
log.Err(err).
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Msg("failed to load store")
|
Msg("failed to load store")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
if err := json.NewDecoder(file).Decode(&store); err != nil {
|
if err := json.NewDecoder(file).Decode(&store); err != nil {
|
||||||
logging.Err(err).
|
log.Err(err).
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Msg("failed to load store")
|
Msg("failed to load store")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stores[ns] = store
|
stores[ns] = store
|
||||||
logging.Debug().
|
log.Debug().
|
||||||
Str("namespace", string(ns)).
|
Str("namespace", string(ns)).
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Msg("loaded store")
|
Msg("loaded store")
|
||||||
|
@ -77,7 +78,7 @@ func loadNS[T store](ns namespace) T {
|
||||||
func save() error {
|
func save() error {
|
||||||
errs := gperr.NewBuilder("failed to save data stores")
|
errs := gperr.NewBuilder("failed to save data stores")
|
||||||
for ns, store := range stores {
|
for ns, store := range stores {
|
||||||
if err := utils.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
|
if err := serialization.SaveJSON(filepath.Join(storesPath, string(ns)+".json"), &store, 0o644); err != nil {
|
||||||
errs.Add(err)
|
errs.Add(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -86,7 +87,7 @@ func save() error {
|
||||||
|
|
||||||
func Store[VT any](namespace namespace) MapStore[VT] {
|
func Store[VT any](namespace namespace) MapStore[VT] {
|
||||||
if _, ok := stores[namespace]; ok {
|
if _, ok := stores[namespace]; ok {
|
||||||
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||||
}
|
}
|
||||||
store := loadNS[*MapStore[VT]](namespace)
|
store := loadNS[*MapStore[VT]](namespace)
|
||||||
stores[namespace] = store
|
stores[namespace] = store
|
||||||
|
@ -95,7 +96,7 @@ func Store[VT any](namespace namespace) MapStore[VT] {
|
||||||
|
|
||||||
func Object[Ptr Initializer](namespace namespace) Ptr {
|
func Object[Ptr Initializer](namespace namespace) Ptr {
|
||||||
if _, ok := stores[namespace]; ok {
|
if _, ok := stores[namespace]; ok {
|
||||||
logging.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
log.Fatal().Str("namespace", string(namespace)).Msg("namespace already exists")
|
||||||
}
|
}
|
||||||
obj := loadNS[*ObjectStore[Ptr]](namespace)
|
obj := loadNS[*ObjectStore[Ptr]](namespace)
|
||||||
stores[namespace] = obj
|
stores[namespace] = obj
|
||||||
|
@ -117,7 +118,7 @@ func (s *MapStore[VT]) UnmarshalJSON(data []byte) error {
|
||||||
}
|
}
|
||||||
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
|
s.Map = xsync.NewMap[string, VT](xsync.WithPresize(len(tmp)))
|
||||||
for k, v := range tmp {
|
for k, v := range tmp {
|
||||||
s.Map.Store(k, v)
|
s.Store(k, v)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
maxmind "github.com/yusing/go-proxy/internal/maxmind/types"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
@ -83,6 +83,9 @@ func NewAccessLogger(parent task.Parent, cfg AnyConfig) (*AccessLogger, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if io == nil {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
return NewAccessLoggerWithIO(parent, io, cfg), nil
|
return NewAccessLoggerWithIO(parent, io, cfg), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +123,7 @@ func NewAccessLoggerWithIO(parent task.Parent, writer WriterWithName, anyCfg Any
|
||||||
bufSize: MinBufferSize,
|
bufSize: MinBufferSize,
|
||||||
lineBufPool: synk.NewBytesPool(),
|
lineBufPool: synk.NewBytesPool(),
|
||||||
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
|
errRateLimiter: rate.NewLimiter(rate.Every(errRateLimit), errBurst),
|
||||||
logger: logging.With().Str("file", writer.Name()).Logger(),
|
logger: log.With().Str("file", writer.Name()).Logger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
l.supportRotate = unwrap[supportRotate](writer)
|
l.supportRotate = unwrap[supportRotate](writer)
|
||||||
|
@ -181,7 +184,7 @@ func (l *AccessLogger) LogError(req *http.Request, err error) {
|
||||||
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
|
func (l *AccessLogger) LogACL(info *maxmind.IPInfo, blocked bool) {
|
||||||
line := l.lineBufPool.Get()
|
line := l.lineBufPool.Get()
|
||||||
defer l.lineBufPool.Put(line)
|
defer l.lineBufPool.Put(line)
|
||||||
line = l.ACLFormatter.AppendACLLog(line, info, blocked)
|
line = l.AppendACLLog(line, info, blocked)
|
||||||
if line[len(line)-1] != '\n' {
|
if line[len(line)-1] != '\n' {
|
||||||
line = append(line, '\n')
|
line = append(line, '\n')
|
||||||
}
|
}
|
||||||
|
@ -194,7 +197,7 @@ func (l *AccessLogger) ShouldRotate() bool {
|
||||||
|
|
||||||
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
func (l *AccessLogger) Rotate() (result *RotateResult, err error) {
|
||||||
if !l.ShouldRotate() {
|
if !l.ShouldRotate() {
|
||||||
return nil, nil
|
return nil, nil //nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
l.writer.Flush()
|
l.writer.Flush()
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -126,6 +126,6 @@ func DefaultACLLoggerConfig() *ACLLoggerConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
utils.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
|
serialization.RegisterDefaultValueFactory(DefaultRequestLoggerConfig)
|
||||||
utils.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
|
serialization.RegisterDefaultValueFactory(DefaultACLLoggerConfig)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
. "github.com/yusing/go-proxy/internal/logging/accesslog"
|
. "github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) {
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
|
|
||||||
var config RequestLoggerConfig
|
var config RequestLoggerConfig
|
||||||
err = utils.MapUnmarshalValidate(parsed, &config)
|
err = serialization.MapUnmarshalValidate(parsed, &config)
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
|
|
||||||
expect.Equal(t, config.Format, FormatCombined)
|
expect.Equal(t, config.Format, FormatCombined)
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,20 +35,19 @@ func newFileIO(path string) (SupportRotate, error) {
|
||||||
if opened, ok := openedFiles[path]; ok {
|
if opened, ok := openedFiles[path]; ok {
|
||||||
opened.refCount.Add()
|
opened.refCount.Add()
|
||||||
return opened, nil
|
return opened, nil
|
||||||
} else {
|
|
||||||
// cannot open as O_APPEND as we need Seek and WriteAt
|
|
||||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("access log open error: %w", err)
|
|
||||||
}
|
|
||||||
if _, err := f.Seek(0, io.SeekEnd); err != nil {
|
|
||||||
return nil, fmt.Errorf("access log seek error: %w", err)
|
|
||||||
}
|
|
||||||
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
|
|
||||||
openedFiles[path] = file
|
|
||||||
go file.closeOnZero()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cannot open as O_APPEND as we need Seek and WriteAt
|
||||||
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("access log open error: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := f.Seek(0, io.SeekEnd); err != nil {
|
||||||
|
return nil, fmt.Errorf("access log seek error: %w", err)
|
||||||
|
}
|
||||||
|
file = &File{f: f, path: path, refCount: utils.NewRefCounter()}
|
||||||
|
openedFiles[path] = file
|
||||||
|
go file.closeOnZero()
|
||||||
return file, nil
|
return file, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,7 +89,7 @@ func (f *File) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *File) closeOnZero() {
|
func (f *File) closeOnZero() {
|
||||||
defer logging.Debug().
|
defer log.Debug().
|
||||||
Str("path", f.path).
|
Str("path", f.path).
|
||||||
Msg("access log closed")
|
Msg("access log closed")
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
//nolint:zerologlint
|
|
||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -10,6 +9,8 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
|
||||||
|
zerologlog "github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -61,22 +62,6 @@ func InitLogger(out ...io.Writer) {
|
||||||
log.SetOutput(writer)
|
log.SetOutput(writer)
|
||||||
log.SetPrefix("")
|
log.SetPrefix("")
|
||||||
log.SetFlags(0)
|
log.SetFlags(0)
|
||||||
|
zerolog.TimeFieldFormat = timeFmt
|
||||||
|
zerologlog.Logger = logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func DiscardLogger() { zerolog.SetGlobalLevel(zerolog.Disabled) }
|
|
||||||
|
|
||||||
func AddHook(h zerolog.Hook) { logger = logger.Hook(h) }
|
|
||||||
|
|
||||||
func GetLogger() *zerolog.Logger { return &logger }
|
|
||||||
func With() zerolog.Context { return logger.With() }
|
|
||||||
|
|
||||||
func WithLevel(level zerolog.Level) *zerolog.Event { return logger.WithLevel(level) }
|
|
||||||
|
|
||||||
func Info() *zerolog.Event { return logger.Info() }
|
|
||||||
func Warn() *zerolog.Event { return logger.Warn() }
|
|
||||||
func Error() *zerolog.Event { return logger.Error() }
|
|
||||||
func Err(err error) *zerolog.Event { return logger.Err(err) }
|
|
||||||
func Debug() *zerolog.Event { return logger.Debug() }
|
|
||||||
func Fatal() *zerolog.Event { return logger.Fatal() }
|
|
||||||
func Panic() *zerolog.Event { return logger.Panic() }
|
|
||||||
func Trace() *zerolog.Event { return logger.Trace() }
|
|
||||||
|
|
|
@ -2,8 +2,8 @@ package maxmind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -28,6 +28,6 @@ func (cfg *Config) Validate() gperr.Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *Config) Logger() *zerolog.Logger {
|
func (cfg *Config) Logger() *zerolog.Logger {
|
||||||
l := logging.With().Str("database", string(cfg.Database)).Logger()
|
l := log.With().Str("database", string(cfg.Database)).Logger()
|
||||||
return &l
|
return &l
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,8 +10,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||||
)
|
)
|
||||||
|
@ -47,7 +47,7 @@ var initDataDirOnce sync.Once
|
||||||
|
|
||||||
func initDataDir() {
|
func initDataDir() {
|
||||||
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil {
|
if err := os.MkdirAll(saveBaseDir, 0o755); err != nil {
|
||||||
logging.Error().Err(err).Msg("failed to create metrics data directory")
|
log.Error().Err(err).Msg("failed to create metrics data directory")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ func NewPoller[T any, AggregateT json.Marshaler](
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Poller[T, AggregateT]) savePath() string {
|
func (p *Poller[T, AggregateT]) savePath() string {
|
||||||
return filepath.Join(saveBaseDir, fmt.Sprintf("%s.json", p.name))
|
return filepath.Join(saveBaseDir, p.name+".json")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Poller[T, AggregateT]) load() error {
|
func (p *Poller[T, AggregateT]) load() error {
|
||||||
|
@ -135,13 +135,14 @@ func (p *Poller[T, AggregateT]) pollWithTimeout(ctx context.Context) {
|
||||||
|
|
||||||
func (p *Poller[T, AggregateT]) Start() {
|
func (p *Poller[T, AggregateT]) Start() {
|
||||||
t := task.RootTask("poller." + p.name)
|
t := task.RootTask("poller." + p.name)
|
||||||
|
l := log.With().Str("name", p.name).Logger()
|
||||||
err := p.load()
|
err := p.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
logging.Error().Err(err).Msgf("failed to load last metrics data for %s", p.name)
|
l.Err(err).Msg("failed to load last metrics data")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logging.Debug().Msgf("Loaded last metrics data for %s, %d entries", p.name, p.period.Total())
|
l.Debug().Int("entries", p.period.Total()).Msgf("Loaded last metrics data")
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -154,11 +155,13 @@ func (p *Poller[T, AggregateT]) Start() {
|
||||||
gatherErrsTicker.Stop()
|
gatherErrsTicker.Stop()
|
||||||
saveTicker.Stop()
|
saveTicker.Stop()
|
||||||
|
|
||||||
p.save()
|
if err := p.save(); err != nil {
|
||||||
|
l.Err(err).Msg("failed to save metrics data")
|
||||||
|
}
|
||||||
t.Finish(nil)
|
t.Finish(nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logging.Debug().Msgf("Starting poller %s with interval %s", p.name, pollInterval)
|
l.Debug().Dur("interval", pollInterval).Msg("Starting poller")
|
||||||
|
|
||||||
p.pollWithTimeout(t.Context())
|
p.pollWithTimeout(t.Context())
|
||||||
|
|
||||||
|
@ -176,7 +179,7 @@ func (p *Poller[T, AggregateT]) Start() {
|
||||||
case <-gatherErrsTicker.C:
|
case <-gatherErrsTicker.C:
|
||||||
errs, ok := p.gatherErrs()
|
errs, ok := p.gatherErrs()
|
||||||
if ok {
|
if ok {
|
||||||
logging.Error().Msg(errs)
|
log.Error().Msg(errs)
|
||||||
}
|
}
|
||||||
p.clearErrs()
|
p.clearErrs()
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/shirou/gopsutil/v4/cpu"
|
"github.com/shirou/gopsutil/v4/cpu"
|
||||||
"github.com/shirou/gopsutil/v4/disk"
|
"github.com/shirou/gopsutil/v4/disk"
|
||||||
"github.com/shirou/gopsutil/v4/mem"
|
"github.com/shirou/gopsutil/v4/mem"
|
||||||
|
@ -16,7 +17,6 @@ import (
|
||||||
"github.com/shirou/gopsutil/v4/warning"
|
"github.com/shirou/gopsutil/v4/warning"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/metrics/period"
|
"github.com/yusing/go-proxy/internal/metrics/period"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ func getSystemInfo(ctx context.Context, lastResult *SystemInfo) (*SystemInfo, er
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if allWarnings.HasError() {
|
if allWarnings.HasError() {
|
||||||
logging.Warn().Msg(allWarnings.String())
|
log.Warn().Msg(allWarnings.String())
|
||||||
}
|
}
|
||||||
if allErrors.HasError() {
|
if allErrors.HasError() {
|
||||||
return nil, allErrors.Error()
|
return nil, allErrors.Error()
|
||||||
|
@ -195,7 +195,7 @@ func (s *SystemInfo) collectDisksInfo(ctx context.Context, lastResult *SystemInf
|
||||||
if len(s.Disks) == 0 {
|
if len(s.Disks) == 0 {
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
logging.Warn().Msg(errs.String())
|
log.Warn().Msg(errs.String())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WriteBody(w http.ResponseWriter, body []byte) {
|
func WriteBody(w http.ResponseWriter, body []byte) {
|
||||||
|
@ -14,9 +14,9 @@ func WriteBody(w http.ResponseWriter, body []byte) {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, http.ErrHandlerTimeout),
|
case errors.Is(err, http.ErrHandlerTimeout),
|
||||||
errors.Is(err, context.DeadlineExceeded):
|
errors.Is(err, context.DeadlineExceeded):
|
||||||
logging.Err(err).Msg("timeout writing body")
|
log.Err(err).Msg("timeout writing body")
|
||||||
default:
|
default:
|
||||||
logging.Err(err).Msg("failed to write body")
|
log.Err(err).Msg("failed to write body")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,11 +9,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func warnNoMatchDomains() {
|
func warnNoMatchDomains() {
|
||||||
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
log.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||||
}
|
}
|
||||||
|
|
||||||
var warnNoMatchDomainOnce sync.Once
|
var warnNoMatchDomainOnce sync.Once
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
"github.com/yusing/go-proxy/internal/net/gphttp/loadbalancer/types"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -47,7 +47,7 @@ func New(cfg *Config) *LoadBalancer {
|
||||||
lb := &LoadBalancer{
|
lb := &LoadBalancer{
|
||||||
Config: new(Config),
|
Config: new(Config),
|
||||||
pool: pool.New[Server]("loadbalancer." + cfg.Link),
|
pool: pool.New[Server]("loadbalancer." + cfg.Link),
|
||||||
l: logging.With().Str("name", cfg.Link).Logger(),
|
l: log.With().Str("name", cfg.Link).Logger(),
|
||||||
}
|
}
|
||||||
lb.UpdateConfigIfNeeded(cfg)
|
lb.UpdateConfigIfNeeded(cfg)
|
||||||
return lb
|
return lb
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
type Weight uint16
|
type Weight int
|
||||||
|
|
|
@ -4,14 +4,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||||
return logging.WithLevel(level).
|
return log.WithLevel(level). //nolint:zerologlint
|
||||||
Str("remote", r.RemoteAddr).
|
Str("remote", r.RemoteAddr).
|
||||||
Str("host", r.Host).
|
Str("host", r.Host).
|
||||||
Str("uri", r.Method+" "+r.RequestURI)
|
Str("uri", r.Method+" "+r.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
func LogError(r *http.Request) *zerolog.Event { return reqLogger(r, zerolog.ErrorLevel) }
|
||||||
|
|
|
@ -4,8 +4,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/auth"
|
"github.com/yusing/go-proxy/internal/auth"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
@ -55,7 +55,7 @@ func PreRequest(p Provider, w http.ResponseWriter, r *http.Request) (proceed boo
|
||||||
"FormHTML": p.FormHTML(),
|
"FormHTML": p.FormHTML(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error().Err(err).Msg("failed to execute captcha page")
|
log.Error().Err(err).Msg("failed to execute captcha page")
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
utils.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
serialization.MustRegisterValidation("status_code", func(fl validator.FieldLevel) bool {
|
||||||
statusCode := fl.Field().Int()
|
statusCode := fl.Field().Int()
|
||||||
return gphttp.IsStatusCodeValid(int(statusCode))
|
return gphttp.IsStatusCodeValid(int(statusCode))
|
||||||
})
|
})
|
||||||
|
@ -60,7 +60,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||||
ipStr = r.RemoteAddr
|
ipStr = r.RemoteAddr
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(ipStr)
|
ip := net.ParseIP(ipStr)
|
||||||
for _, cidr := range wl.CIDRWhitelistOpts.Allow {
|
for _, cidr := range wl.Allow {
|
||||||
if cidr.Contains(ip) {
|
if cidr.Contains(ip) {
|
||||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||||
allow = true
|
allow = true
|
||||||
|
@ -70,7 +70,7 @@ func (wl *cidrWhitelist) checkIP(w http.ResponseWriter, r *http.Request) bool {
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||||
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.CIDRWhitelistOpts.Allow)
|
wl.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.Allow)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !allow {
|
if !allow {
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
||||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
"message": testMessage,
|
"message": testMessage,
|
||||||
})
|
})
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
ExpectError(t, serialization.ErrValidationError, err)
|
||||||
})
|
})
|
||||||
t.Run("invalid cidr", func(t *testing.T) {
|
t.Run("invalid cidr", func(t *testing.T) {
|
||||||
_, err := CIDRWhiteList.New(OptionsRaw{
|
_, err := CIDRWhiteList.New(OptionsRaw{
|
||||||
|
@ -56,7 +56,7 @@ func TestCIDRWhitelistValidation(t *testing.T) {
|
||||||
"status_code": 600,
|
"status_code": 600,
|
||||||
"message": testMessage,
|
"message": testMessage,
|
||||||
})
|
})
|
||||||
ExpectError(t, utils.ErrValidationError, err)
|
ExpectError(t, serialization.ErrValidationError, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -9,8 +10,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/atomic"
|
"github.com/yusing/go-proxy/internal/utils/atomic"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
@ -89,21 +90,29 @@ func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
|
cfCIDRsLastUpdate.Store(time.Now().Add(-cfCIDRsUpdateRetryInterval - cfCIDRsUpdateInterval))
|
||||||
logging.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
log.Err(err).Msg("failed to update cloudflare range, retry in " + strutils.FormatDuration(cfCIDRsUpdateRetryInterval))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(cfCIDRs) == 0 {
|
if len(cfCIDRs) == 0 {
|
||||||
logging.Warn().Msg("cloudflare CIDR range is empty")
|
log.Warn().Msg("cloudflare CIDR range is empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfCIDRsLastUpdate.Store(time.Now())
|
cfCIDRsLastUpdate.Store(time.Now())
|
||||||
logging.Info().Msg("cloudflare CIDR range updated")
|
log.Info().Msg("cloudflare CIDR range updated")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
func fetchUpdateCFIPRange(endpoint string, cfCIDRs *[]*types.CIDR) error {
|
||||||
resp, err := http.Get(endpoint)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req) //nolint:gosec
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
"github.com/yusing/go-proxy/internal/net/gphttp/middleware/errorpage"
|
||||||
|
@ -32,7 +32,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||||
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
if !gphttp.IsSuccess(resp.StatusCode) && (contentType.IsHTML() || contentType.IsPlainText()) {
|
||||||
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
errorPage, ok := errorpage.GetErrorPageByStatus(resp.StatusCode)
|
||||||
if ok {
|
if ok {
|
||||||
logging.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
log.Debug().Msgf("error page for status %d loaded", resp.StatusCode)
|
||||||
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
_, _ = io.Copy(io.Discard, resp.Body) // drain the original body
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
|
||||||
|
@ -40,7 +40,7 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
|
||||||
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
|
||||||
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
|
||||||
} else {
|
} else {
|
||||||
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
log.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
||||||
filename := path[len(StaticFilePathPrefix):]
|
filename := path[len(StaticFilePathPrefix):]
|
||||||
file, ok := errorpage.GetStaticFile(filename)
|
file, ok := errorpage.GetStaticFile(filename)
|
||||||
if !ok {
|
if !ok {
|
||||||
logging.Error().Msg("unable to load resource " + filename)
|
log.Error().Msg("unable to load resource " + filename)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
ext := filepath.Ext(filename)
|
ext := filepath.Ext(filename)
|
||||||
|
@ -68,10 +68,10 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
|
||||||
case ".css":
|
case ".css":
|
||||||
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
|
||||||
default:
|
default:
|
||||||
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
log.Error().Msgf("unexpected file type %q for %s", ext, filename)
|
||||||
}
|
}
|
||||||
if _, err := w.Write(file); err != nil {
|
if _, err := w.Write(file); err != nil {
|
||||||
logging.Err(err).Msg("unable to write resource " + filename)
|
log.Err(err).Msg("unable to write resource " + filename)
|
||||||
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
http.Error(w, "Error page failure", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
|
@ -48,7 +48,7 @@ func GetErrorPageByStatus(statusCode int) (content []byte, ok bool) {
|
||||||
func loadContent() {
|
func loadContent() {
|
||||||
files, err := U.ListFiles(errPagesBasePath, 0)
|
files, err := U.ListFiles(errPagesBasePath, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Err(err).Msg("failed to list error page resources")
|
log.Err(err).Msg("failed to list error page resources")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
|
@ -57,11 +57,11 @@ func loadContent() {
|
||||||
}
|
}
|
||||||
content, err := os.ReadFile(file)
|
content, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
log.Warn().Err(err).Msgf("failed to read error page resource %s", file)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
file = path.Base(file)
|
file = path.Base(file)
|
||||||
logging.Info().Msgf("error page resource %s loaded", file)
|
log.Info().Msgf("error page resource %s loaded", file)
|
||||||
fileContentMap.Store(file, content)
|
fileContentMap.Store(file, content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,9 +83,9 @@ func watchDir() {
|
||||||
loadContent()
|
loadContent()
|
||||||
case events.ActionFileDeleted:
|
case events.ActionFileDeleted:
|
||||||
fileContentMap.Delete(filename)
|
fileContentMap.Delete(filename)
|
||||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||||
case events.ActionFileRenamed:
|
case events.ActionFileRenamed:
|
||||||
logging.Warn().Msgf("error page resource %s deleted", filename)
|
log.Warn().Msgf("error page resource %s deleted", filename)
|
||||||
fileContentMap.Delete(filename)
|
fileContentMap.Delete(filename)
|
||||||
loadContent()
|
loadContent()
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,11 +8,11 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -87,7 +87,7 @@ func NewMiddleware[ImplType any]() *Middleware {
|
||||||
func (m *Middleware) enableTrace() {
|
func (m *Middleware) enableTrace() {
|
||||||
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
|
||||||
tracer.enableTrace()
|
tracer.enableTrace()
|
||||||
logging.Trace().Msgf("middleware %s enabled trace", m.name)
|
log.Trace().Msgf("middleware %s enabled trace", m.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,14 +118,14 @@ func (m *Middleware) apply(optsRaw OptionsRaw) gperr.Error {
|
||||||
"priority": optsRaw["priority"],
|
"priority": optsRaw["priority"],
|
||||||
"bypass": optsRaw["bypass"],
|
"bypass": optsRaw["bypass"],
|
||||||
}
|
}
|
||||||
if err := utils.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
optsRaw = maps.Clone(optsRaw)
|
optsRaw = maps.Clone(optsRaw)
|
||||||
for k := range commonOpts {
|
for k := range commonOpts {
|
||||||
delete(optsRaw, k)
|
delete(optsRaw, k)
|
||||||
}
|
}
|
||||||
return utils.MapUnmarshalValidate(optsRaw, m.impl)
|
return serialization.MapUnmarshalValidate(optsRaw, m.impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) finalize() error {
|
func (m *Middleware) finalize() error {
|
||||||
|
|
|
@ -3,9 +3,9 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
@ -59,7 +59,7 @@ func LoadComposeFiles() {
|
||||||
errs := gperr.NewBuilder("middleware compile errors")
|
errs := gperr.NewBuilder("middleware compile errors")
|
||||||
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
middlewareDefs, err := utils.ListFiles(common.MiddlewareComposeBasePath, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Err(err).Msg("failed to list middleware definitions")
|
log.Err(err).Msg("failed to list middleware definitions")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, defFile := range middlewareDefs {
|
for _, defFile := range middlewareDefs {
|
||||||
|
@ -75,7 +75,7 @@ func LoadComposeFiles() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
allMiddlewares[name] = m
|
allMiddlewares[name] = m
|
||||||
logging.Info().
|
log.Info().
|
||||||
Str("src", path.Base(defFile)).
|
Str("src", path.Base(defFile)).
|
||||||
Str("name", name).
|
Str("name", name).
|
||||||
Msg("middleware loaded")
|
Msg("middleware loaded")
|
||||||
|
@ -94,7 +94,7 @@ func LoadComposeFiles() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
allMiddlewares[name] = m
|
allMiddlewares[name] = m
|
||||||
logging.Info().
|
log.Info().
|
||||||
Str("src", path.Base(defFile)).
|
Str("src", path.Base(defFile)).
|
||||||
Str("name", name).
|
Str("name", name).
|
||||||
Msg("middleware loaded")
|
Msg("middleware loaded")
|
||||||
|
|
|
@ -25,7 +25,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
"github.com/yusing/go-proxy/internal/logging/accesslog"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
|
@ -138,7 +138,7 @@ func NewReverseProxy(name string, target *types.URL, transport http.RoundTripper
|
||||||
panic("nil transport")
|
panic("nil transport")
|
||||||
}
|
}
|
||||||
rp := &ReverseProxy{
|
rp := &ReverseProxy{
|
||||||
Logger: logging.With().Str("name", name).Logger(),
|
Logger: log.With().Str("name", name).Logger(),
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
TargetName: name,
|
TargetName: name,
|
||||||
TargetURL: target,
|
TargetURL: target,
|
||||||
|
@ -173,17 +173,17 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
||||||
case errors.Is(err, context.Canceled),
|
case errors.Is(err, context.Canceled),
|
||||||
errors.Is(err, io.EOF),
|
errors.Is(err, io.EOF),
|
||||||
errors.Is(err, context.DeadlineExceeded):
|
errors.Is(err, context.DeadlineExceeded):
|
||||||
logging.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
|
log.Debug().Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||||
default:
|
default:
|
||||||
var recordErr tls.RecordHeaderError
|
var recordErr tls.RecordHeaderError
|
||||||
if errors.As(err, &recordErr) {
|
if errors.As(err, &recordErr) {
|
||||||
logging.Error().
|
log.Error().
|
||||||
Str("url", reqURL).
|
Str("url", reqURL).
|
||||||
Msgf(`scheme was likely misconfigured as https,
|
Msgf(`scheme was likely misconfigured as https,
|
||||||
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
|
try setting "proxy.%s.scheme" back to "http"`, p.TargetName)
|
||||||
logging.Err(err).Msg("underlying error")
|
log.Err(err).Msg("underlying error")
|
||||||
} else {
|
} else {
|
||||||
logging.Err(err).Str("url", reqURL).Msg("http proxy error")
|
log.Err(err).Str("url", reqURL).Msg("http proxy error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,7 +220,6 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
transport := p.Transport
|
transport := p.Transport
|
||||||
|
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
/* trunk-ignore(golangci-lint/revive) */
|
|
||||||
if ctx.Done() != nil {
|
if ctx.Done() != nil {
|
||||||
// CloseNotifier predates context.Context, and has been
|
// CloseNotifier predates context.Context, and has been
|
||||||
// entirely superseded by it. If the request contains
|
// entirely superseded by it. If the request contains
|
||||||
|
@ -352,7 +351,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) //nolint:contextcheck
|
||||||
|
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
|
|
||||||
|
@ -507,18 +506,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
|
||||||
res.Header = rw.Header()
|
res.Header = rw.Header()
|
||||||
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||||
if err := res.Write(brw); err != nil {
|
if err := res.Write(brw); err != nil {
|
||||||
/* trunk-ignore(golangci-lint/errorlint) */
|
//nolint:errorlint
|
||||||
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
|
p.errorHandler(rw, req, fmt.Errorf("response write: %s", err), true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := brw.Flush(); err != nil {
|
if err := brw.Flush(); err != nil {
|
||||||
/* trunk-ignore(golangci-lint/errorlint) */
|
//nolint:errorlint
|
||||||
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
|
p.errorHandler(rw, req, fmt.Errorf("response flush: %s", err), true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
|
bdp := U.NewBidirectionalPipe(req.Context(), conn, backConn)
|
||||||
/* trunk-ignore(golangci-lint/errcheck) */
|
//nolint:errcheck
|
||||||
bdp.Start()
|
bdp.Start()
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,14 +9,14 @@ import (
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/http3"
|
"github.com/quic-go/quic-go/http3"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/acl"
|
"github.com/yusing/go-proxy/internal/acl"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CertProvider interface {
|
type CertProvider interface {
|
||||||
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
@ -53,7 +53,7 @@ func StartServer(parent task.Parent, opt Options) (s *Server) {
|
||||||
func NewServer(opt Options) (s *Server) {
|
func NewServer(opt Options) (s *Server) {
|
||||||
var httpSer, httpsSer *http.Server
|
var httpSer, httpsSer *http.Server
|
||||||
|
|
||||||
logger := logging.With().Str("server", opt.Name).Logger()
|
logger := log.With().Str("server", opt.Name).Logger()
|
||||||
|
|
||||||
certAvailable := false
|
certAvailable := false
|
||||||
if opt.CertProvider != nil {
|
if opt.CertProvider != nil {
|
||||||
|
|
|
@ -2,7 +2,7 @@ package notif
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NotificationConfig struct {
|
type NotificationConfig struct {
|
||||||
|
@ -46,5 +46,5 @@ func (cfg *NotificationConfig) UnmarshalMap(m map[string]any) (err gperr.Error)
|
||||||
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
|
Withf("expect %s or %s", ProviderWebhook, ProviderGotify)
|
||||||
}
|
}
|
||||||
|
|
||||||
return utils.MapUnmarshalValidate(m, cfg.Provider)
|
return serialization.MapUnmarshalValidate(m, cfg.Provider)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -181,7 +181,7 @@ func TestNotificationConfig(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var cfg NotificationConfig
|
var cfg NotificationConfig
|
||||||
provider := tt.cfg["provider"]
|
provider := tt.cfg["provider"]
|
||||||
err := utils.MapUnmarshalValidate(tt.cfg, &cfg)
|
err := serialization.MapUnmarshalValidate(tt.cfg, &cfg)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
ExpectHasError(t, err)
|
ExpectHasError(t, err)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -8,14 +8,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Provider interface {
|
Provider interface {
|
||||||
utils.CustomValidator
|
serialization.CustomValidator
|
||||||
|
|
||||||
GetName() string
|
GetName() string
|
||||||
GetURL() string
|
GetURL() string
|
||||||
|
@ -73,7 +73,7 @@ func (msg *LogMessage) notify(ctx context.Context, provider Provider) error {
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent:
|
case http.StatusOK, http.StatusCreated, http.StatusAccepted, http.StatusNoContent:
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
logging.Debug().
|
log.Debug().
|
||||||
Str("provider", provider.GetName()).
|
Str("provider", provider.GetName()).
|
||||||
Str("url", provider.GetURL()).
|
Str("url", provider.GetURL()).
|
||||||
Str("status", resp.Status).
|
Str("status", resp.Status).
|
||||||
|
|
|
@ -7,12 +7,12 @@ import (
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/route"
|
"github.com/yusing/go-proxy/internal/route"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
"github.com/yusing/go-proxy/internal/watcher"
|
"github.com/yusing/go-proxy/internal/watcher"
|
||||||
)
|
)
|
||||||
|
@ -36,7 +36,7 @@ func DockerProviderImpl(name, dockerHost string) ProviderImpl {
|
||||||
return &DockerProvider{
|
return &DockerProvider{
|
||||||
name,
|
name,
|
||||||
dockerHost,
|
dockerHost,
|
||||||
logging.With().Str("type", "docker").Str("name", name).Logger(),
|
log.With().Str("type", "docker").Str("name", name).Logger(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
|
||||||
// Always non-nil.
|
// Always non-nil.
|
||||||
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
|
func (p *DockerProvider) routesFromContainerLabels(container *docker.Container) (route.Routes, gperr.Error) {
|
||||||
if !container.IsExplicit && p.IsExplicitOnly() {
|
if !container.IsExplicit && p.IsExplicitOnly() {
|
||||||
return nil, nil
|
return make(route.Routes, 0), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := make(route.Routes, len(container.Aliases))
|
routes := make(route.Routes, len(container.Aliases))
|
||||||
|
@ -180,7 +180,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
|
||||||
}
|
}
|
||||||
|
|
||||||
// deserialize map into entry object
|
// deserialize map into entry object
|
||||||
err := U.MapUnmarshalValidate(entryMap, r)
|
err := serialization.MapUnmarshalValidate(entryMap, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.Add(err.Subject(alias))
|
errs.Add(err.Subject(alias))
|
||||||
} else {
|
} else {
|
||||||
|
@ -189,7 +189,7 @@ func (p *DockerProvider) routesFromContainerLabels(container *docker.Container)
|
||||||
}
|
}
|
||||||
if wildcardProps != nil {
|
if wildcardProps != nil {
|
||||||
for _, re := range routes {
|
for _, re := range routes {
|
||||||
if err := U.MapUnmarshalValidate(wildcardProps, re); err != nil {
|
if err := serialization.MapUnmarshalValidate(wildcardProps, re); err != nil {
|
||||||
errs.Add(err.Subject(docker.WildcardAlias))
|
errs.Add(err.Subject(docker.WildcardAlias))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,11 +6,11 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/route"
|
"github.com/yusing/go-proxy/internal/route"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
W "github.com/yusing/go-proxy/internal/watcher"
|
W "github.com/yusing/go-proxy/internal/watcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
|
||||||
impl := &FileProvider{
|
impl := &FileProvider{
|
||||||
fileName: filename,
|
fileName: filename,
|
||||||
path: path.Join(common.ConfigBasePath, filename),
|
path: path.Join(common.ConfigBasePath, filename),
|
||||||
l: logging.With().Str("type", "file").Str("name", filename).Logger(),
|
l: log.With().Str("type", "file").Str("name", filename).Logger(),
|
||||||
}
|
}
|
||||||
_, err := os.Stat(impl.path)
|
_, err := os.Stat(impl.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,7 +34,7 @@ func FileProviderImpl(filename string) (ProviderImpl, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func validate(data []byte) (routes route.Routes, err gperr.Error) {
|
func validate(data []byte) (routes route.Routes, err gperr.Error) {
|
||||||
err = utils.UnmarshalValidateYAML(data, &routes)
|
err = serialization.UnmarshalValidateYAML(data, &routes)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/docker/docker/api/types/container"
|
"github.com/docker/docker/api/types/container"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/agent/pkg/agent"
|
"github.com/yusing/go-proxy/agent/pkg/agent"
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/homepage"
|
"github.com/yusing/go-proxy/internal/homepage"
|
||||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
netutils "github.com/yusing/go-proxy/internal/net"
|
netutils "github.com/yusing/go-proxy/internal/net"
|
||||||
net "github.com/yusing/go-proxy/internal/net/types"
|
net "github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/proxmox"
|
"github.com/yusing/go-proxy/internal/proxmox"
|
||||||
|
@ -116,7 +116,7 @@ func (r *Route) Validate() gperr.Error {
|
||||||
Subject(containerName)
|
Subject(containerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := logging.With().Str("container", containerName).Logger()
|
l := log.With().Str("container", containerName).Logger()
|
||||||
|
|
||||||
l.Info().Msg("checking if container is running")
|
l.Info().Msg("checking if container is running")
|
||||||
running, err := node.LXCIsRunning(ctx, vmid)
|
running, err := node.LXCIsRunning(ctx, vmid)
|
||||||
|
|
|
@ -3,7 +3,7 @@ package rules
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ func TestParseRule(t *testing.T) {
|
||||||
var rules struct {
|
var rules struct {
|
||||||
Rules Rules
|
Rules Rules
|
||||||
}
|
}
|
||||||
err := utils.MapUnmarshalValidate(utils.SerializedObject{"rules": test}, &rules)
|
err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectEqual(t, len(rules.Rules), len(test))
|
ExpectEqual(t, len(rules.Rules), len(test))
|
||||||
ExpectEqual(t, rules.Rules[0].Name, "test")
|
ExpectEqual(t, rules.Rules[0].Name, "test")
|
||||||
|
|
|
@ -5,9 +5,9 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/idlewatcher"
|
"github.com/yusing/go-proxy/internal/idlewatcher"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
net "github.com/yusing/go-proxy/internal/net/types"
|
net "github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/route/routes"
|
"github.com/yusing/go-proxy/internal/route/routes"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -32,7 +32,7 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
|
||||||
// TODO: support non-coherent scheme
|
// TODO: support non-coherent scheme
|
||||||
return &StreamRoute{
|
return &StreamRoute{
|
||||||
Route: base,
|
Route: base,
|
||||||
l: logging.With().
|
l: log.With().
|
||||||
Str("type", string(base.Scheme)).
|
Str("type", string(base.Scheme)).
|
||||||
Str("name", base.Name()).
|
Str("name", base.Name()).
|
||||||
Logger(),
|
Logger(),
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
. "github.com/yusing/go-proxy/internal/route"
|
. "github.com/yusing/go-proxy/internal/route"
|
||||||
route "github.com/yusing/go-proxy/internal/route/types"
|
route "github.com/yusing/go-proxy/internal/route/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/serialization"
|
||||||
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
expect "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cfg := Route{}
|
cfg := Route{}
|
||||||
tt.input["host"] = "internal"
|
tt.input["host"] = "internal"
|
||||||
err := utils.MapUnmarshalValidate(tt.input, &cfg)
|
err := serialization.MapUnmarshalValidate(tt.input, &cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/types"
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
@ -88,7 +88,7 @@ func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
|
||||||
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
|
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
|
||||||
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
|
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
|
log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ func (conn *UDPConn) read() (err error) {
|
||||||
conn.buf.oobn = 0
|
conn.buf.oobn = 0
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -110,7 +110,7 @@ func (conn *UDPConn) read() (err error) {
|
||||||
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
|
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
|
||||||
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
|
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
|
log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -120,12 +120,12 @@ func (conn *UDPConn) write() (err error) {
|
||||||
case *net.UDPConn:
|
case *net.UDPConn:
|
||||||
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
|
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
|
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logging.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
|
log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package serialization
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -12,7 +12,9 @@ import (
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/functional"
|
"github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
@ -40,14 +42,14 @@ var (
|
||||||
|
|
||||||
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
var mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()
|
||||||
|
|
||||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
var defaultValues = xsync.NewMapOf[reflect.Type, func() any]()
|
||||||
|
|
||||||
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
func RegisterDefaultValueFactory[T any](factory func() *T) {
|
||||||
t := reflect.TypeFor[T]()
|
t := reflect.TypeFor[T]()
|
||||||
if t.Kind() == reflect.Ptr {
|
if t.Kind() == reflect.Ptr {
|
||||||
panic("pointer of pointer")
|
panic("pointer of pointer")
|
||||||
}
|
}
|
||||||
if defaultValues.Has(t) {
|
if _, ok := defaultValues.Load(t); ok {
|
||||||
panic("default value for " + t.String() + " already registered")
|
panic("default value for " + t.String() + " already registered")
|
||||||
}
|
}
|
||||||
defaultValues.Store(t, func() any { return factory() })
|
defaultValues.Store(t, func() any { return factory() })
|
||||||
|
@ -259,7 +261,7 @@ func mapUnmarshalValidate(src SerializedObject, dst any, checkValidateTag bool)
|
||||||
errs.Add(err.Subject(k))
|
errs.Add(err.Subject(k))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(NearestField(k, mapping))))
|
errs.Add(ErrUnknownField.Subject(k).With(gperr.DoYouMean(utils.NearestField(k, mapping))))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasValidateTag && checkValidateTag {
|
if hasValidateTag && checkValidateTag {
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package serialization
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
|
@ -1,4 +1,4 @@
|
||||||
package utils
|
package serialization
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
|
@ -7,9 +7,9 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -116,14 +116,14 @@ func (t *Task) Finish(reason any) {
|
||||||
func (t *Task) finish(reason any) {
|
func (t *Task) finish(reason any) {
|
||||||
t.cancel(fmtCause(reason))
|
t.cancel(fmtCause(reason))
|
||||||
if !waitWithTimeout(t.childrenDone) {
|
if !waitWithTimeout(t.childrenDone) {
|
||||||
logging.Debug().
|
log.Debug().
|
||||||
Str("task", t.name).
|
Str("task", t.name).
|
||||||
Strs("subtasks", t.listChildren()).
|
Strs("subtasks", t.listChildren()).
|
||||||
Msg("Timeout waiting for subtasks to finish")
|
Msg("Timeout waiting for subtasks to finish")
|
||||||
}
|
}
|
||||||
go t.runCallbacks()
|
go t.runCallbacks()
|
||||||
if !waitWithTimeout(t.callbacksDone) {
|
if !waitWithTimeout(t.callbacksDone) {
|
||||||
logging.Debug().
|
log.Debug().
|
||||||
Str("task", t.name).
|
Str("task", t.name).
|
||||||
Strs("callbacks", t.listCallbacks()).
|
Strs("callbacks", t.listCallbacks()).
|
||||||
Msg("Timeout waiting for callbacks to finish")
|
Msg("Timeout waiting for callbacks to finish")
|
||||||
|
@ -134,7 +134,7 @@ func (t *Task) finish(reason any) {
|
||||||
}
|
}
|
||||||
t.parent.subChildCount()
|
t.parent.subChildCount()
|
||||||
allTasks.Remove(t)
|
allTasks.Remove(t)
|
||||||
logging.Trace().Msg("task " + t.name + " finished")
|
log.Trace().Msg("task " + t.name + " finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||||
|
@ -166,7 +166,7 @@ func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Trace().Msg("task " + child.name + " started")
|
log.Trace().Msg("task " + child.name + " started")
|
||||||
return child
|
return child
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,7 +189,7 @@ func (t *Task) MarshalText() ([]byte, error) {
|
||||||
func (t *Task) invokeWithRecover(fn func(), caller string) {
|
func (t *Task) invokeWithRecover(fn func(), caller string) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
logging.Error().
|
log.Error().
|
||||||
Interface("err", err).
|
Interface("err", err).
|
||||||
Msg("panic in task " + t.name + "." + caller)
|
Msg("panic in task " + t.name + "." + caller)
|
||||||
if common.IsDebug {
|
if common.IsDebug {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -68,10 +68,10 @@ func GracefulShutdown(timeout time.Duration) (err error) {
|
||||||
case <-after:
|
case <-after:
|
||||||
b, err := json.Marshal(DebugTaskList())
|
b, err := json.Marshal(DebugTaskList())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Warn().Err(err).Msg("failed to marshal tasks")
|
log.Warn().Err(err).Msg("failed to marshal tasks")
|
||||||
return context.DeadlineExceeded
|
return context.DeadlineExceeded
|
||||||
}
|
}
|
||||||
logging.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
log.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
||||||
return context.DeadlineExceeded
|
return context.DeadlineExceeded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,6 @@ func WaitExit(shutdownTimeout int) {
|
||||||
<-sig
|
<-sig
|
||||||
|
|
||||||
// gracefully shutdown
|
// gracefully shutdown
|
||||||
logging.Info().Msg("shutting down")
|
log.Info().Msg("shutting down")
|
||||||
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
|
_ = GracefulShutdown(time.Second * time.Duration(shutdownTimeout))
|
||||||
}
|
}
|
||||||
|
|
21
internal/utils/go.mod
Normal file
21
internal/utils/go.mod
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
module github.com/yusing/go-proxy/internal/utils
|
||||||
|
|
||||||
|
go 1.24.3
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/goccy/go-yaml v1.17.1
|
||||||
|
github.com/puzpuzpuz/xsync/v4 v4.1.0
|
||||||
|
github.com/rs/zerolog v1.34.0
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
|
go.uber.org/atomic v1.11.0
|
||||||
|
golang.org/x/text v0.25.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
36
internal/utils/go.sum
Normal file
36
internal/utils/go.sum
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY=
|
||||||
|
github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/puzpuzpuz/xsync/v4 v4.1.0 h1:x9eHRl4QhZFIPJ17yl4KKW9xLyVWbb3/Yq4SXpjF71U=
|
||||||
|
github.com/puzpuzpuz/xsync/v4 v4.1.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||||
|
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||||
|
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/synk"
|
"github.com/yusing/go-proxy/internal/utils/synk"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -91,20 +90,20 @@ func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.Re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p BidirectionalPipe) Start() gperr.Error {
|
func (p BidirectionalPipe) Start() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
b := gperr.NewBuilder("bidirectional pipe error")
|
var srcErr, dstErr error
|
||||||
go func() {
|
go func() {
|
||||||
b.Add(p.pSrcDst.Start())
|
srcErr = p.pSrcDst.Start()
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
b.Add(p.pDstSrc.Start())
|
dstErr = p.pDstSrc.Start()
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return b.Error()
|
return errors.Join(srcErr, dstErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpFlusher interface {
|
type httpFlusher interface {
|
||||||
|
@ -143,30 +142,18 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) {
|
||||||
wCloser, wCanClose := dst.Writer.(io.Closer)
|
wCloser, wCanClose := dst.Writer.(io.Closer)
|
||||||
rCloser, rCanClose := src.Reader.(io.Closer)
|
rCloser, rCanClose := src.Reader.(io.Closer)
|
||||||
if wCanClose || rCanClose {
|
if wCanClose || rCanClose {
|
||||||
if src.ctx == dst.ctx {
|
go func() {
|
||||||
go func() {
|
select {
|
||||||
<-src.ctx.Done()
|
case <-src.ctx.Done():
|
||||||
if wCanClose {
|
case <-dst.ctx.Done():
|
||||||
wCloser.Close()
|
|
||||||
}
|
|
||||||
if rCanClose {
|
|
||||||
rCloser.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
if wCloser != nil {
|
|
||||||
go func() {
|
|
||||||
<-src.ctx.Done()
|
|
||||||
wCloser.Close()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
if rCloser != nil {
|
if rCanClose {
|
||||||
go func() {
|
defer rCloser.Close()
|
||||||
<-dst.ctx.Done()
|
|
||||||
rCloser.Close()
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
if wCanClose {
|
||||||
|
defer wCloser.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
flusher := getHTTPFlusher(dst.Writer)
|
flusher := getHTTPFlusher(dst.Writer)
|
||||||
canFlush := flusher != nil
|
canFlush := flusher != nil
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v4"
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -29,12 +29,12 @@ func (p Pool[T]) Name() string {
|
||||||
func (p Pool[T]) Add(obj T) {
|
func (p Pool[T]) Add(obj T) {
|
||||||
p.checkExists(obj.Key())
|
p.checkExists(obj.Key())
|
||||||
p.m.Store(obj.Key(), obj)
|
p.m.Store(obj.Key(), obj)
|
||||||
logging.Info().Msgf("%s: added %s", p.name, obj.Name())
|
log.Info().Msgf("%s: added %s", p.name, obj.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Pool[T]) Del(obj T) {
|
func (p Pool[T]) Del(obj T) {
|
||||||
p.m.Delete(obj.Key())
|
p.m.Delete(obj.Key())
|
||||||
logging.Info().Msgf("%s: removed %s", p.name, obj.Name())
|
log.Info().Msgf("%s: removed %s", p.name, obj.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Pool[T]) Get(key string) (T, bool) {
|
func (p Pool[T]) Get(key string) (T, bool) {
|
||||||
|
|
|
@ -5,11 +5,11 @@ package pool
|
||||||
import (
|
import (
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p Pool[T]) checkExists(key string) {
|
func (p Pool[T]) checkExists(key string) {
|
||||||
if _, ok := p.m.Load(key); ok {
|
if _, ok := p.m.Load(key); ok {
|
||||||
logging.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
|
log.Warn().Msgf("%s: key %s already exists\nstacktrace: %s", p.name, key, string(debug.Stack()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,35 @@
|
||||||
package synk
|
package synk
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"runtime"
|
||||||
"os/signal"
|
"unsafe"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type weakBuf = unsafe.Pointer
|
||||||
|
|
||||||
|
func makeWeak(b *[]byte) weakBuf {
|
||||||
|
ptr := runtime_registerWeakPointer(unsafe.Pointer(b))
|
||||||
|
runtime.KeepAlive(ptr)
|
||||||
|
addCleanup(b, addGCed, cap(*b))
|
||||||
|
return weakBuf(ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBufFromWeak(w weakBuf) []byte {
|
||||||
|
ptr := (*[]byte)(runtime_makeStrongFromWeak(w))
|
||||||
|
if ptr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return *ptr
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname runtime_registerWeakPointer weak.runtime_registerWeakPointer
|
||||||
|
func runtime_registerWeakPointer(unsafe.Pointer) unsafe.Pointer
|
||||||
|
|
||||||
|
//go:linkname runtime_makeStrongFromWeak weak.runtime_makeStrongFromWeak
|
||||||
|
func runtime_makeStrongFromWeak(unsafe.Pointer) unsafe.Pointer
|
||||||
|
|
||||||
type BytesPool struct {
|
type BytesPool struct {
|
||||||
pool chan []byte
|
pool chan weakBuf
|
||||||
initSize int
|
initSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,19 +41,15 @@ const (
|
||||||
const (
|
const (
|
||||||
InPoolLimit = 32 * mb
|
InPoolLimit = 32 * mb
|
||||||
|
|
||||||
DefaultInitBytes = 32 * kb
|
DefaultInitBytes = 4 * kb
|
||||||
PoolThreshold = 64 * kb
|
PoolThreshold = 256 * kb
|
||||||
DropThresholdHigh = 4 * mb
|
DropThresholdHigh = 4 * mb
|
||||||
|
|
||||||
PoolSize = InPoolLimit / PoolThreshold
|
PoolSize = InPoolLimit / PoolThreshold
|
||||||
|
|
||||||
CleanupInterval = 5 * time.Second
|
|
||||||
MaxDropsPerCycle = 10
|
|
||||||
MaxChecksPerCycle = 100
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var bytesPool = &BytesPool{
|
var bytesPool = &BytesPool{
|
||||||
pool: make(chan []byte, PoolSize),
|
pool: make(chan weakBuf, PoolSize),
|
||||||
initSize: DefaultInitBytes,
|
initSize: DefaultInitBytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,12 +58,18 @@ func NewBytesPool() *BytesPool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BytesPool) Get() []byte {
|
func (p *BytesPool) Get() []byte {
|
||||||
select {
|
for {
|
||||||
case b := <-p.pool:
|
select {
|
||||||
subInPoolSize(int64(cap(b)))
|
case bWeak := <-p.pool:
|
||||||
return b
|
bPtr := getBufFromWeak(bWeak)
|
||||||
default:
|
if bPtr == nil {
|
||||||
return make([]byte, 0, p.initSize)
|
continue
|
||||||
|
}
|
||||||
|
addReused(cap(bPtr))
|
||||||
|
return bPtr
|
||||||
|
default:
|
||||||
|
return make([]byte, 0, p.initSize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,90 +77,43 @@ func (p *BytesPool) GetSized(size int) []byte {
|
||||||
if size <= PoolThreshold {
|
if size <= PoolThreshold {
|
||||||
return make([]byte, size)
|
return make([]byte, size)
|
||||||
}
|
}
|
||||||
select {
|
for {
|
||||||
case b := <-p.pool:
|
|
||||||
if size <= cap(b) {
|
|
||||||
subInPoolSize(int64(cap(b)))
|
|
||||||
return b[:size]
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case p.pool <- b:
|
case bWeak := <-p.pool:
|
||||||
addInPoolSize(int64(cap(b)))
|
bPtr := getBufFromWeak(bWeak)
|
||||||
|
if bPtr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
capB := cap(bPtr)
|
||||||
|
if capB >= size {
|
||||||
|
addReused(capB)
|
||||||
|
return (bPtr)[:size]
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.pool <- bWeak:
|
||||||
|
default:
|
||||||
|
// just drop it
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
default:
|
return make([]byte, size)
|
||||||
}
|
}
|
||||||
return make([]byte, size)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BytesPool) Put(b []byte) {
|
func (p *BytesPool) Put(b []byte) {
|
||||||
size := cap(b)
|
size := cap(b)
|
||||||
if size > DropThresholdHigh || poolFull() {
|
if size <= PoolThreshold || size > DropThresholdHigh {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b = b[:0]
|
b = b[:0]
|
||||||
|
w := makeWeak(&b)
|
||||||
select {
|
select {
|
||||||
case p.pool <- b:
|
case p.pool <- w:
|
||||||
addInPoolSize(int64(size))
|
|
||||||
return
|
|
||||||
default:
|
default:
|
||||||
// just drop it
|
// just drop it
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var inPoolSize int64
|
|
||||||
|
|
||||||
func addInPoolSize(size int64) {
|
|
||||||
atomic.AddInt64(&inPoolSize, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func subInPoolSize(size int64) {
|
|
||||||
atomic.AddInt64(&inPoolSize, -size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Periodically drop some buffers to prevent excessive memory usage
|
initPoolStats()
|
||||||
go func() {
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, os.Interrupt)
|
|
||||||
|
|
||||||
cleanupTicker := time.NewTicker(CleanupInterval)
|
|
||||||
defer cleanupTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-cleanupTicker.C:
|
|
||||||
dropBuffers()
|
|
||||||
case <-sigCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func poolFull() bool {
|
|
||||||
return atomic.LoadInt64(&inPoolSize) >= InPoolLimit
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropBuffers removes excess buffers from the pool when it grows too large.
|
|
||||||
func dropBuffers() {
|
|
||||||
// Check if pool has more than a threshold of buffers
|
|
||||||
count := 0
|
|
||||||
droppedSize := 0
|
|
||||||
checks := 0
|
|
||||||
for count < MaxDropsPerCycle && checks < MaxChecksPerCycle && atomic.LoadInt64(&inPoolSize) > InPoolLimit*2/3 {
|
|
||||||
select {
|
|
||||||
case b := <-bytesPool.pool:
|
|
||||||
n := cap(b)
|
|
||||||
subInPoolSize(int64(n))
|
|
||||||
droppedSize += n
|
|
||||||
count++
|
|
||||||
default:
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
checks++
|
|
||||||
}
|
|
||||||
if count > 0 {
|
|
||||||
logging.Debug().Int("dropped", count).Int("size", droppedSize).Msg("dropped buffers from pool")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024}
|
var sizes = []int{1024, 4096, 16384, 65536, 32 * 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024}
|
||||||
|
|
||||||
func BenchmarkBytesPool_GetSmall(b *testing.B) {
|
func BenchmarkBytesPool_GetSmall(b *testing.B) {
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
|
|
55
internal/utils/synk/pool_debug.go
Normal file
55
internal/utils/synk/pool_debug.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
//go:build !production
|
||||||
|
|
||||||
|
package synk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
numReused, sizeReused uint64
|
||||||
|
numGCed, sizeGCed uint64
|
||||||
|
)
|
||||||
|
|
||||||
|
func addReused(size int) {
|
||||||
|
atomic.AddUint64(&numReused, 1)
|
||||||
|
atomic.AddUint64(&sizeReused, uint64(size))
|
||||||
|
}
|
||||||
|
|
||||||
|
func addGCed(size int) {
|
||||||
|
atomic.AddUint64(&numGCed, 1)
|
||||||
|
atomic.AddUint64(&sizeGCed, uint64(size))
|
||||||
|
}
|
||||||
|
|
||||||
|
var addCleanup = runtime.AddCleanup[[]byte, int]
|
||||||
|
|
||||||
|
func initPoolStats() {
|
||||||
|
go func() {
|
||||||
|
statsTicker := time.NewTicker(5 * time.Second)
|
||||||
|
defer statsTicker.Stop()
|
||||||
|
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, os.Interrupt)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sig:
|
||||||
|
return
|
||||||
|
case <-statsTicker.C:
|
||||||
|
log.Info().
|
||||||
|
Uint64("numReused", atomic.LoadUint64(&numReused)).
|
||||||
|
Str("sizeReused", strutils.FormatByteSize(atomic.LoadUint64(&sizeReused))).
|
||||||
|
Uint64("numGCed", atomic.LoadUint64(&numGCed)).
|
||||||
|
Str("sizeGCed", strutils.FormatByteSize(atomic.LoadUint64(&sizeGCed))).
|
||||||
|
Msg("bytes pool stats")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
8
internal/utils/synk/pool_prod.go
Normal file
8
internal/utils/synk/pool_prod.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
//go:build production
|
||||||
|
|
||||||
|
package synk
|
||||||
|
|
||||||
|
func addReused(size int) {}
|
||||||
|
func addGCed(size int) {}
|
||||||
|
func initPoolStats() {}
|
||||||
|
func addCleanup(ptr *[]byte, cleanup func(int), arg int) {}
|
|
@ -2,14 +2,16 @@ package expect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var isTest = strings.HasSuffix(os.Args[0], ".test")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if common.IsTest {
|
if isTest {
|
||||||
// force verbose output
|
// force verbose output
|
||||||
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,19 +1,11 @@
|
||||||
package expect
|
package expect
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
if common.IsTest {
|
|
||||||
os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExpectNoError(t *testing.T, err error) {
|
func ExpectNoError(t *testing.T, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -3,7 +3,6 @@ package utils
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,8 +37,6 @@ func init() {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-task.RootContext().Done():
|
|
||||||
return
|
|
||||||
case <-timeNowTicker.C:
|
case <-timeNowTicker.C:
|
||||||
shouldCallTimeNow.Store(true)
|
shouldCallTimeNow.Store(true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||||
)
|
)
|
||||||
|
@ -41,13 +41,13 @@ func NewDirectoryWatcher(parent task.Parent, dirPath string) *DirWatcher {
|
||||||
//! subdirectories are not watched
|
//! subdirectories are not watched
|
||||||
w, err := fsnotify.NewWatcher()
|
w, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Panic().Err(err).Msg("unable to create fs watcher")
|
log.Panic().Err(err).Msg("unable to create fs watcher")
|
||||||
}
|
}
|
||||||
if err = w.Add(dirPath); err != nil {
|
if err = w.Add(dirPath); err != nil {
|
||||||
logging.Panic().Err(err).Msg("unable to create fs watcher")
|
log.Panic().Err(err).Msg("unable to create fs watcher")
|
||||||
}
|
}
|
||||||
helper := &DirWatcher{
|
helper := &DirWatcher{
|
||||||
Logger: logging.With().
|
Logger: log.With().
|
||||||
Str("type", "dir").
|
Str("type", "dir").
|
||||||
Str("path", dirPath).
|
Str("path", dirPath).
|
||||||
Logger(),
|
Logger(),
|
||||||
|
|
|
@ -8,9 +8,9 @@ import (
|
||||||
docker_events "github.com/docker/docker/api/types/events"
|
docker_events "github.com/docker/docker/api/types/events"
|
||||||
"github.com/docker/docker/api/types/filters"
|
"github.com/docker/docker/api/types/filters"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/watcher/events"
|
"github.com/yusing/go-proxy/internal/watcher/events"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cEventCh, cErrCh := client.Events(ctx, options)
|
cEventCh, cErrCh := client.Events(ctx, options)
|
||||||
defer logging.Debug().Str("host", client.Address()).Msg("docker watcher closed")
|
defer log.Debug().Str("host", client.Address()).Msg("docker watcher closed")
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -153,7 +153,7 @@ func checkConnection(ctx context.Context, client *docker.SharedClient) bool {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := client.CheckConnection(ctx)
|
err := client.CheckConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Debug().Err(err).Msg("docker watcher: connection failed")
|
log.Debug().Err(err).Msg("docker watcher: connection failed")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -7,9 +7,9 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/go-proxy/internal/docker"
|
"github.com/yusing/go-proxy/internal/docker"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/notif"
|
"github.com/yusing/go-proxy/internal/notif"
|
||||||
"github.com/yusing/go-proxy/internal/route/routes"
|
"github.com/yusing/go-proxy/internal/route/routes"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
@ -48,7 +48,7 @@ func NewMonitor(r routes.Route) health.HealthMonCheck {
|
||||||
case routes.StreamRoute:
|
case routes.StreamRoute:
|
||||||
mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig())
|
mon = NewRawHealthMonitor(&r.TargetURL().URL, r.HealthCheckConfig())
|
||||||
default:
|
default:
|
||||||
logging.Panic().Msgf("unexpected route type: %T", r)
|
log.Panic().Msgf("unexpected route type: %T", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if r.IsDocker() {
|
if r.IsDocker() {
|
||||||
|
@ -91,7 +91,7 @@ func (mon *monitor) Start(parent task.Parent) gperr.Error {
|
||||||
mon.task = parent.Subtask("health_monitor", true)
|
mon.task = parent.Subtask("health_monitor", true)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
logger := logging.With().Str("name", mon.service).Logger()
|
logger := log.With().Str("name", mon.service).Logger()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if mon.status.Load() != health.StatusError {
|
if mon.status.Load() != health.StatusError {
|
||||||
|
@ -221,7 +221,7 @@ func (mon *monitor) MarshalJSON() ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mon *monitor) checkUpdateHealth() error {
|
func (mon *monitor) checkUpdateHealth() error {
|
||||||
logger := logging.With().Str("name", mon.Name()).Logger()
|
logger := log.With().Str("name", mon.Name()).Logger()
|
||||||
result, err := mon.checkHealth()
|
result, err := mon.checkHealth()
|
||||||
|
|
||||||
var lastStatus health.Status
|
var lastStatus health.Status
|
||||||
|
|
|
@ -18,7 +18,7 @@ func GetLastVersion() Version {
|
||||||
|
|
||||||
func GetVersionHTTPHandler() http.HandlerFunc {
|
func GetVersionHTTPHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(GetVersion().String()))
|
fmt.Fprint(w, GetVersion().String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,16 +34,16 @@ func init() {
|
||||||
// lastVersion = ParseVersion(lastVersionStr)
|
// lastVersion = ParseVersion(lastVersionStr)
|
||||||
// }
|
// }
|
||||||
// if err != nil && !os.IsNotExist(err) {
|
// if err != nil && !os.IsNotExist(err) {
|
||||||
// logging.Warn().Err(err).Msg("failed to read version file")
|
// log.Warn().Err(err).Msg("failed to read version file")
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
// if err := f.Truncate(0); err != nil {
|
// if err := f.Truncate(0); err != nil {
|
||||||
// logging.Warn().Err(err).Msg("failed to truncate version file")
|
// log.Warn().Err(err).Msg("failed to truncate version file")
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
// _, err = f.WriteString(version)
|
// _, err = f.WriteString(version)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// logging.Warn().Err(err).Msg("failed to save version file")
|
// log.Warn().Err(err).Msg("failed to save version file")
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,4 +2,19 @@ module github.com/yusing/go-proxy/socketproxy
|
||||||
|
|
||||||
go 1.24.3
|
go 1.24.3
|
||||||
|
|
||||||
require github.com/gorilla/mux v1.8.1
|
replace github.com/yusing/go-proxy/internal/utils => ../internal/utils
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/gorilla/mux v1.8.1
|
||||||
|
github.com/yusing/go-proxy/internal/utils v0.0.0-00010101000000-000000000000
|
||||||
|
golang.org/x/net v0.40.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/rs/zerolog v1.34.0 // indirect
|
||||||
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
golang.org/x/text v0.25.0 // indirect
|
||||||
|
)
|
||||||
|
|
|
@ -1,2 +1,34 @@
|
||||||
|
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
|
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||||
|
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
|
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
|
||||||
|
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||||
|
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||||
|
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
|
@ -4,30 +4,33 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/yusing/go-proxy/socketproxy/pkg/reverseproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
|
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
|
||||||
|
|
||||||
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
|
func dialDockerSocket(socket string) func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
return dialer.DialContext(ctx, "unix", DockerSocket)
|
return func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return dialer.DialContext(ctx, "unix", socket)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var DockerSocketHandler = dockerSocketHandler
|
var DockerSocketHandler = dockerSocketHandler
|
||||||
|
|
||||||
func dockerSocketHandler() http.HandlerFunc {
|
func dockerSocketHandler(socket string) http.HandlerFunc {
|
||||||
rp := &httputil.ReverseProxy{
|
rp := &reverseproxy.ReverseProxy{
|
||||||
Director: func(req *http.Request) {
|
Director: func(req *http.Request) {
|
||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
req.URL.Host = "api.moby.localhost"
|
req.URL.Host = "api.moby.localhost"
|
||||||
req.RequestURI = req.URL.String()
|
req.RequestURI = req.URL.String()
|
||||||
},
|
},
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DialContext: dialDockerSocket,
|
DialContext: dialDockerSocket(socket),
|
||||||
|
DisableCompression: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +44,7 @@ func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
|
||||||
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
|
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
|
||||||
func NewHandler() http.Handler {
|
func NewHandler() http.Handler {
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
socketHandler := DockerSocketHandler()
|
socketHandler := DockerSocketHandler(DockerSocket)
|
||||||
|
|
||||||
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
|
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
|
||||||
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
|
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
. "github.com/yusing/go-proxy/socketproxy/pkg"
|
. "github.com/yusing/go-proxy/socketproxy/pkg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mockDockerSocketHandler() http.HandlerFunc {
|
func mockDockerSocketHandler(_ string) http.HandlerFunc {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("mock docker response"))
|
w.Write([]byte("mock docker response"))
|
||||||
|
|
367
socket-proxy/pkg/reverseproxy/reverse_proxy.go
Normal file
367
socket-proxy/pkg/reverseproxy/reverse_proxy.go
Normal file
|
@ -0,0 +1,367 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
// License URL: https://cs.opensource.google/go/go/+/master:LICENSE
|
||||||
|
|
||||||
|
// HTTP reverse proxy handler
|
||||||
|
|
||||||
|
package reverseproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
"golang.org/x/net/http/httpguts"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReverseProxy is an HTTP Handler that takes an incoming request and
|
||||||
|
// sends it to another server, proxying the response back to the
|
||||||
|
// client.
|
||||||
|
//
|
||||||
|
// 1xx responses are forwarded to the client if the underlying
|
||||||
|
// transport supports ClientTrace.Got1xxResponse.
|
||||||
|
type ReverseProxy struct {
|
||||||
|
// Director is a function which modifies
|
||||||
|
// the request into a new request to be sent
|
||||||
|
// using Transport. Its response is then copied
|
||||||
|
// back to the original client unmodified.
|
||||||
|
// Director must not access the provided Request
|
||||||
|
// after returning.
|
||||||
|
//
|
||||||
|
// By default, the X-Forwarded-For header is set to the
|
||||||
|
// value of the client IP address. If an X-Forwarded-For
|
||||||
|
// header already exists, the client IP is appended to the
|
||||||
|
// existing values. As a special case, if the header
|
||||||
|
// exists in the Request.Header map but has a nil value
|
||||||
|
// (such as when set by the Director func), the X-Forwarded-For
|
||||||
|
// header is not modified.
|
||||||
|
//
|
||||||
|
// To prevent IP spoofing, be sure to delete any pre-existing
|
||||||
|
// X-Forwarded-For header coming from the client or
|
||||||
|
// an untrusted proxy.
|
||||||
|
//
|
||||||
|
// Hop-by-hop headers are removed from the request after
|
||||||
|
// Director returns, which can remove headers added by
|
||||||
|
// Director. Use a Rewrite function instead to ensure
|
||||||
|
// modifications to the request are preserved.
|
||||||
|
//
|
||||||
|
// Unparsable query parameters are removed from the outbound
|
||||||
|
// request if Request.Form is set after Director returns.
|
||||||
|
//
|
||||||
|
// At most one of Rewrite or Director may be set.
|
||||||
|
Director func(*http.Request)
|
||||||
|
|
||||||
|
// The transport used to perform proxy requests.
|
||||||
|
// If nil, http.DefaultTransport is used.
|
||||||
|
Transport http.RoundTripper
|
||||||
|
|
||||||
|
// ErrorLog specifies an optional logger for errors
|
||||||
|
// that occur when attempting to proxy the request.
|
||||||
|
// If nil, logging is done via the log package's standard logger.
|
||||||
|
ErrorLog *log.Logger
|
||||||
|
|
||||||
|
// ErrorHandler is an optional function that handles errors
|
||||||
|
// reaching the backend or errors from ModifyResponse.
|
||||||
|
//
|
||||||
|
// If nil, the default is to log the provided error and return
|
||||||
|
// a 502 Status Bad Gateway response.
|
||||||
|
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyHeader(dst, src http.Header) {
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
|
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||||
|
// Connection header field. These are the headers defined by the
|
||||||
|
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||||
|
// compatibility.
|
||||||
|
var hopHeaders = []string{
|
||||||
|
"Connection",
|
||||||
|
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te", // canonicalized version of "TE"
|
||||||
|
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||||
|
"Transfer-Encoding",
|
||||||
|
"Upgrade",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
p.logf("http: proxy error: %v", err)
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
|
||||||
|
if p.ErrorHandler != nil {
|
||||||
|
return p.ErrorHandler
|
||||||
|
}
|
||||||
|
return p.defaultErrorHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
transport := p.Transport
|
||||||
|
ctx := req.Context()
|
||||||
|
|
||||||
|
outreq := req.Clone(ctx)
|
||||||
|
if req.ContentLength == 0 {
|
||||||
|
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
|
||||||
|
}
|
||||||
|
if outreq.Body != nil {
|
||||||
|
// Reading from the request body after returning from a handler is not
|
||||||
|
// allowed, and the RoundTrip goroutine that reads the Body can outlive
|
||||||
|
// this handler. This can lead to a crash if the handler panics (see
|
||||||
|
// Issue 46866). Although calling Close doesn't guarantee there isn't
|
||||||
|
// any Read in flight after the handle returns, in practice it's safe to
|
||||||
|
// read after closing it.
|
||||||
|
defer outreq.Body.Close()
|
||||||
|
}
|
||||||
|
if outreq.Header == nil {
|
||||||
|
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Director(outreq)
|
||||||
|
outreq.Close = false
|
||||||
|
|
||||||
|
reqUpType := upgradeType(outreq.Header)
|
||||||
|
if !IsPrint(reqUpType) {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Del("Forwarded")
|
||||||
|
removeHopByHopHeaders(outreq.Header)
|
||||||
|
|
||||||
|
// Issue 21096: tell backend applications that care about trailer support
|
||||||
|
// that we support trailers. (We do, but we don't go out of our way to
|
||||||
|
// advertise that unless the incoming client request thought it was worth
|
||||||
|
// mentioning.) Note that we look at req.Header, not outreq.Header, since
|
||||||
|
// the latter has passed through removeHopByHopHeaders.
|
||||||
|
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||||
|
outreq.Header.Set("Te", "trailers")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := outreq.Header["User-Agent"]; !ok {
|
||||||
|
// If the outbound request doesn't have a User-Agent header set,
|
||||||
|
// don't send the default Go HTTP client User-Agent.
|
||||||
|
outreq.Header.Set("User-Agent", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
roundTripMutex sync.Mutex
|
||||||
|
roundTripDone bool
|
||||||
|
)
|
||||||
|
trace := &httptrace.ClientTrace{
|
||||||
|
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||||
|
roundTripMutex.Lock()
|
||||||
|
defer roundTripMutex.Unlock()
|
||||||
|
if roundTripDone {
|
||||||
|
// If RoundTrip has returned, don't try to further modify
|
||||||
|
// the ResponseWriter's header map.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h := rw.Header()
|
||||||
|
copyHeader(h, http.Header(header))
|
||||||
|
rw.WriteHeader(code)
|
||||||
|
|
||||||
|
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||||
|
clear(h)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||||
|
|
||||||
|
res, err := transport.RoundTrip(outreq)
|
||||||
|
roundTripMutex.Lock()
|
||||||
|
roundTripDone = true
|
||||||
|
roundTripMutex.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
p.getErrorHandler()(rw, outreq, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||||
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
|
p.handleUpgradeResponse(rw, outreq, res)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
removeHopByHopHeaders(res.Header)
|
||||||
|
|
||||||
|
copyHeader(rw.Header(), res.Header)
|
||||||
|
|
||||||
|
// The "Trailer" header isn't included in the Transport's response,
|
||||||
|
// at least for *http.Transport. Build it up from Trailer.
|
||||||
|
announcedTrailers := len(res.Trailer)
|
||||||
|
if announcedTrailers > 0 {
|
||||||
|
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||||
|
for k := range res.Trailer {
|
||||||
|
trailerKeys = append(trailerKeys, k)
|
||||||
|
}
|
||||||
|
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.WriteHeader(res.StatusCode)
|
||||||
|
|
||||||
|
err = utils.CopyCloseWithContext(ctx, rw, res.Body)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
p.getErrorHandler()(rw, req, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Trailer) > 0 {
|
||||||
|
// Force chunking if we saw a response trailer.
|
||||||
|
// This prevents net/http from calculating the length for short
|
||||||
|
// bodies and adding a Content-Length.
|
||||||
|
http.NewResponseController(rw).Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Trailer) == announcedTrailers {
|
||||||
|
copyHeader(rw.Header(), res.Trailer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, vv := range res.Trailer {
|
||||||
|
k = http.TrailerPrefix + k
|
||||||
|
for _, v := range vv {
|
||||||
|
rw.Header().Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeHopByHopHeaders removes hop-by-hop headers.
|
||||||
|
func removeHopByHopHeaders(h http.Header) {
|
||||||
|
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||||
|
for _, f := range h["Connection"] {
|
||||||
|
for sf := range strings.SplitSeq(f, ",") {
|
||||||
|
if sf = textproto.TrimString(sf); sf != "" {
|
||||||
|
h.Del(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||||
|
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||||
|
// preserve it for backwards compatibility.
|
||||||
|
for _, f := range hopHeaders {
|
||||||
|
h.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) logf(format string, args ...any) {
|
||||||
|
if p.ErrorLog != nil {
|
||||||
|
p.ErrorLog.Printf(format, args...)
|
||||||
|
} else {
|
||||||
|
log.Printf(format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func upgradeType(h http.Header) string {
|
||||||
|
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return h.Get("Upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||||
|
reqUpType := upgradeType(req.Header)
|
||||||
|
resUpType := upgradeType(res.Header)
|
||||||
|
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(reqUpType, resUpType) {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backConn, ok := res.Body.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := http.NewResponseController(rw)
|
||||||
|
conn, brw, hijackErr := rc.Hijack()
|
||||||
|
if errors.Is(hijackErr, http.ErrNotSupported) {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backConnCloseCh := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
// Ensure that the cancellation of a request closes the backend.
|
||||||
|
// See issue https://golang.org/issue/35559.
|
||||||
|
select {
|
||||||
|
case <-req.Context().Done():
|
||||||
|
case <-backConnCloseCh:
|
||||||
|
}
|
||||||
|
backConn.Close()
|
||||||
|
}()
|
||||||
|
defer close(backConnCloseCh)
|
||||||
|
|
||||||
|
if hijackErr != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
copyHeader(rw.Header(), res.Header)
|
||||||
|
|
||||||
|
res.Header = rw.Header()
|
||||||
|
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
|
||||||
|
if err := res.Write(brw); err != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := brw.Flush(); err != nil {
|
||||||
|
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errc := make(chan error, 1)
|
||||||
|
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
go spc.copyToBackend(errc)
|
||||||
|
go spc.copyFromBackend(errc)
|
||||||
|
<-errc
|
||||||
|
}
|
||||||
|
|
||||||
|
// switchProtocolCopier exists so goroutines proxying data back and
|
||||||
|
// forth have nice names in stacks.
|
||||||
|
type switchProtocolCopier struct {
|
||||||
|
user, backend io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||||
|
_, err := io.Copy(c.user, c.backend)
|
||||||
|
errc <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||||
|
_, err := io.Copy(c.backend, c.user)
|
||||||
|
errc <- err
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPrint(s string) bool {
|
||||||
|
for _, r := range s {
|
||||||
|
if r < ' ' || r > '~' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue