feat(socket-proxy): implement Docker socket proxy and related configurations

- Updated Dockerfile and Makefile for socket-proxy build.
- Modified go.mod to include necessary dependencies.
- Updated CI workflows for socket-proxy integration.
- Better module isolation
- Code refactor
This commit is contained in:
yusing 2025-05-10 09:47:03 +08:00
parent 4ddfb48b9d
commit 8fe94d6d14
38 changed files with 658 additions and 523 deletions

View file

@ -15,9 +15,10 @@ jobs:
with:
image_name: ${{ github.repository_owner }}/godoxy
tag: nightly
target: main
build-nightly-agent:
uses: ./.github/workflows/docker-image.yml
with:
image_name: ${{ github.repository_owner }}/godoxy-agent
tag: nightly
agent: true
target: agent

View file

@ -12,9 +12,10 @@ jobs:
image_name: ${{ github.repository_owner }}/godoxy
old_image_name: ${{ github.repository_owner }}/go-proxy
tag: latest
target: main
build-prod-agent:
uses: ./.github/workflows/docker-image.yml
with:
image_name: ${{ github.repository_owner }}/godoxy-agent
tag: latest
agent: true
target: agent

View file

@ -0,0 +1,15 @@
name: Docker Image CI (socket-proxy)
on:
push:
paths:
- "socket-proxy/**"
jobs:
build:
uses: ./.github/workflows/docker-image.yml
with:
image_name: ${{ github.repository_owner }}/socket-proxy
tag: latest
target: socket-proxy
dockerfile: socket-proxy.Dockerfile

View file

@ -12,16 +12,20 @@ on:
old_image_name:
required: false
type: string
agent:
target:
required: true
type: string
dockerfile:
required: false
default: false
type: boolean
type: string
default: Dockerfile
env:
REGISTRY: ghcr.io
MAKE_ARGS: agent=${{ inputs.agent && '1' || '0' }}
DIGEST_PATH: /tmp/digests/${{ inputs.agent && 'agent' || 'main' }}
DIGEST_NAME_SUFFIX: ${{ inputs.agent && 'agent' || 'main' }}
MAKE_ARGS: ${{ inputs.target }}=1
DIGEST_PATH: /tmp/digests/${{ inputs.target }}
DIGEST_NAME_SUFFIX: ${{ inputs.target }}
DOCKERFILE: ${{ inputs.dockerfile }}
jobs:
build:
@ -76,6 +80,7 @@ jobs:
with:
platforms: ${{ matrix.platform }}
labels: ${{ steps.meta.outputs.labels }}
file: ${{ env.DOCKERFILE }}
outputs: type=image,name=${{ env.REGISTRY }}/${{ inputs.image_name }},push-by-digest=true,name-canonical=true,push=true
cache-from: |
type=registry,ref=${{ env.REGISTRY }}/${{ inputs.image_name }}:buildcache-${{ env.PLATFORM_PAIR }}-${{ inputs.tag }}

View file

@ -8,11 +8,12 @@ LDFLAGS = -X github.com/yusing/go-proxy/pkg.version=${VERSION}
ifeq ($(agent), 1)
NAME = godoxy-agent
CMD_PATH = ./cmd
PWD = ${shell pwd}/agent
else ifeq ($(socket-proxy), 1)
NAME = godoxy-socket-proxy
PWD = ${shell pwd}/socket-proxy
else
NAME = godoxy
CMD_PATH = ./cmd
PWD = ${shell pwd}
endif
@ -46,7 +47,6 @@ BUILD_FLAGS += -ldflags='$(LDFLAGS)'
BIN_PATH := $(shell pwd)/bin/${NAME}
export NAME
export CMD_PATH
export CGO_ENABLED
export GODOXY_DEBUG
export GODOXY_TRACE
@ -97,13 +97,19 @@ update-deps:
cd ${PWD}/$$path && go get -u ./... && go mod tidy; \
done
mod-tidy:
for path in ${gomod_paths}; do \
echo "go mod tidy $$path"; \
cd ${PWD}/$$path && go mod tidy; \
done
build:
mkdir -p $(shell dirname ${BIN_PATH})
cd ${PWD} && go build ${BUILD_FLAGS} -o ${BIN_PATH} ${CMD_PATH}
cd ${PWD} && go build ${BUILD_FLAGS} -o ${BIN_PATH} ./cmd
${POST_BUILD}
run:
[ -f .env ] && godotenv -f .env go run ${BUILD_FLAGS} ${CMD_PATH}
cd ${PWD} && [ -f .env ] && godotenv -f .env go run ${BUILD_FLAGS} ./cmd
debug:
make NAME="godoxy-test" debug=1 build
@ -125,7 +131,7 @@ ci-test:
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
cloc:
cloc --not-match-f '_test.go$$' cmd internal pkg
cloc --include-lang=Go --not-match-f '_test.go$$' .
push-github:
git push origin $(shell git rev-parse --abbrev-ref HEAD)

View file

@ -1,24 +1,19 @@
package main
import (
"os"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/handler"
"github.com/yusing/go-proxy/agent/pkg/server"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
httpServer "github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg"
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
func main() {
logging.InitLogger(os.Stderr, memlogger.GetMemLogger())
ca := &agent.PEMPair{}
err := ca.Load(env.AgentCACert)
if err != nil {
@ -58,12 +53,12 @@ Tips:
server.StartAgentServer(t, opts)
if env.DockerSocketAddr != "" {
logging.Info().Msgf("Docker socket listening on: %s", env.DockerSocketAddr)
if socketproxy.ListenAddr != "" {
logging.Info().Msgf("Docker socket listening on: %s", socketproxy.ListenAddr)
opts := httpServer.Options{
Name: "docker",
HTTPAddr: env.DockerSocketAddr,
Handler: handler.NewDockerHandler(),
HTTPAddr: socketproxy.ListenAddr,
Handler: socketproxy.NewHandler(),
}
httpServer.StartServer(t, opts)
}

View file

@ -4,19 +4,20 @@ go 1.24.3
replace github.com/yusing/go-proxy => ..
require (
github.com/coder/websocket v1.8.13
github.com/docker/docker v28.1.1+incompatible
github.com/gorilla/mux v1.8.1
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy v0.12.3
)
replace github.com/yusing/go-proxy/socketproxy => ../socket-proxy
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250425105916-b2ad800de7a1
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250502022742-408a348f1b97
require (
github.com/coder/websocket v1.8.13
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
github.com/yusing/go-proxy/socketproxy v0.0.0-00010101000000-000000000000
)
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/PuerkitoBio/goquery v1.10.3 // indirect
@ -28,6 +29,7 @@ require (
github.com/distribution/reference v0.6.0 // indirect
github.com/djherbis/times v1.6.0 // indirect
github.com/docker/cli v28.1.1+incompatible // indirect
github.com/docker/docker v28.1.1+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.2 // indirect
@ -42,6 +44,7 @@ require (
github.com/goccy/go-yaml v1.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/gotify/server/v2 v2.6.3 // indirect
github.com/jinzhu/copier v0.4.0 // indirect
@ -74,8 +77,6 @@ require (
github.com/tklauser/numcpus v0.10.0 // indirect
github.com/vincent-petithory/dataurl v1.0.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.5.2 // indirect

View file

@ -200,12 +200,12 @@ go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 h1:1fTNlAIJZGWLP5FVu0fikVry1IsiUnXjf7QFvoNN3Xw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0/go.mod h1:zjPK58DtkqQFn+YUMbx0M2XV3QgKU0gS9LeGohREyK4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 h1:xJ2qHD0C1BeYVTLLR9sX12+Qb95kfeD/byKj6Ky1pXg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0/go.mod h1:u5BF1xyjstDowA1R5QAO9JHzqK+ublenEW/dyqTjBVk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.31.0 h1:lUsI2TYsQw2r1IASwoROaCnjdj2cvC2+Jbxvk6nHnWU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.31.0/go.mod h1:2HpZxxQurfGxJlJDblybejHB6RX6pmExPNe517hREw4=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=

View file

@ -5,6 +5,8 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@ -14,10 +16,7 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg"
)
@ -80,7 +79,7 @@ func (cfg *AgentConfig) Parse(addr string) error {
return nil
}
func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) error {
func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) error {
clientCert, err := tls.X509KeyPair(crt, key)
if err != nil {
return err
@ -90,7 +89,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(ca)
if !ok {
return gperr.New("invalid ca certificate")
return errors.New("invalid ca certificate")
}
cfg.tlsConfig = &tls.Config{
@ -102,7 +101,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
// create transport and http client
cfg.httpClient = cfg.NewHTTPClient()
ctx, cancel := context.WithTimeout(parent.Context(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// get agent name
@ -131,23 +130,23 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
return nil
}
func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error {
func (cfg *AgentConfig) Start(ctx context.Context) error {
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
if !ok {
return gperr.New("invalid agent host").Subject(cfg.Addr)
return fmt.Errorf("invalid agent host: %s", cfg.Addr)
}
certData, err := os.ReadFile(filepath)
if err != nil {
return gperr.Wrap(err, "failed to read agent certs")
return fmt.Errorf("failed to read agent certs: %w", err)
}
ca, crt, key, err := certs.ExtractCert(certData)
if err != nil {
return gperr.Wrap(err, "failed to extract agent certs")
return fmt.Errorf("failed to extract agent certs: %w", err)
}
return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key))
return cfg.StartWithCerts(ctx, ca, crt, key)
}
func (cfg *AgentConfig) NewHTTPClient() *http.Client {
@ -171,8 +170,10 @@ func (cfg *AgentConfig) Transport() *http.Transport {
}
}
var dialer = &net.Dialer{Timeout: 5 * time.Second}
func (cfg *AgentConfig) DialContext(ctx context.Context) (net.Conn, error) {
return gphttp.DefaultDialer.DialContext(ctx, "tcp", cfg.Addr)
return dialer.DialContext(ctx, "tcp", cfg.Addr)
}
func (cfg *AgentConfig) Name() string {

View file

@ -8,59 +8,59 @@ import (
"net/http/httptest"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
"github.com/stretchr/testify/require"
)
func TestNewAgent(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
ExpectTrue(t, ca != nil)
ExpectTrue(t, srv != nil)
ExpectTrue(t, client != nil)
require.NoError(t, err)
require.NotNil(t, ca)
require.NotNil(t, srv)
require.NotNil(t, client)
}
func TestPEMPair(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) {
var pp PEMPair
err := pp.Load(p.String())
ExpectNoError(t, err)
ExpectEqual(t, p.Cert, pp.Cert)
ExpectEqual(t, p.Key, pp.Key)
require.NoError(t, err)
require.Equal(t, p.Cert, pp.Cert)
require.Equal(t, p.Key, pp.Key)
})
}
}
func TestPEMPairToTLSCert(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) {
cert, err := p.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, cert != nil)
require.NoError(t, err)
require.NotNil(t, cert)
})
}
}
func TestServerClient(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
require.NoError(t, err)
srvTLS, err := srv.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, srvTLS != nil)
require.NoError(t, err)
require.NotNil(t, srvTLS)
clientTLS, err := client.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, clientTLS != nil)
require.NoError(t, err)
require.NotNil(t, clientTLS)
caPool := x509.NewCertPool()
ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert))
require.True(t, caPool.AppendCertsFromPEM(ca.Cert))
srvTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*srvTLS},
@ -86,6 +86,6 @@ func TestServerClient(t *testing.T) {
}
resp, err := httpClient.Get(server.URL)
ExpectNoError(t, err)
ExpectEqual(t, resp.StatusCode, http.StatusOK)
require.NoError(t, err)
require.Equal(t, resp.StatusCode, http.StatusOK)
}

60
agent/pkg/env/env.go vendored
View file

@ -20,35 +20,6 @@ var (
AgentSkipClientCertCheck bool
AgentCACert string
AgentSSLCert string
DockerSocketAddr string
DockerPost bool
DockerRestarts bool
DockerStart bool
DockerStop bool
DockerAuth bool
DockerBuild bool
DockerCommit bool
DockerConfigs bool
DockerContainers bool
DockerDistribution bool
DockerEvents bool
DockerExec bool
DockerGrpc bool
DockerImages bool
DockerInfo bool
DockerNetworks bool
DockerNodes bool
DockerPing bool
DockerPlugins bool
DockerSecrets bool
DockerServices bool
DockerSession bool
DockerSwarm bool
DockerSystem bool
DockerTasks bool
DockerVersion bool
DockerVolumes bool
)
func init() {
@ -62,35 +33,4 @@ func Load() {
AgentCACert = common.GetEnvString("AGENT_CA_CERT", "")
AgentSSLCert = common.GetEnvString("AGENT_SSL_CERT", "")
// docker socket proxy
DockerSocketAddr = common.GetEnvString("DOCKER_SOCKET_ADDR", "127.0.0.1:2375")
DockerPost = common.GetEnvBool("POST", false)
DockerRestarts = common.GetEnvBool("ALLOW_RESTARTS", false)
DockerStart = common.GetEnvBool("ALLOW_START", false)
DockerStop = common.GetEnvBool("ALLOW_STOP", false)
DockerAuth = common.GetEnvBool("AUTH", false)
DockerBuild = common.GetEnvBool("BUILD", false)
DockerCommit = common.GetEnvBool("COMMIT", false)
DockerConfigs = common.GetEnvBool("CONFIGS", false)
DockerContainers = common.GetEnvBool("CONTAINERS", false)
DockerDistribution = common.GetEnvBool("DISTRIBUTION", false)
DockerEvents = common.GetEnvBool("EVENTS", true)
DockerExec = common.GetEnvBool("EXEC", false)
DockerGrpc = common.GetEnvBool("GRPC", false)
DockerImages = common.GetEnvBool("IMAGES", false)
DockerInfo = common.GetEnvBool("INFO", false)
DockerNetworks = common.GetEnvBool("NETWORKS", false)
DockerNodes = common.GetEnvBool("NODES", false)
DockerPing = common.GetEnvBool("PING", true)
DockerPlugins = common.GetEnvBool("PLUGINS", false)
DockerSecrets = common.GetEnvBool("SECRETS", false)
DockerServices = common.GetEnvBool("SERVICES", false)
DockerSession = common.GetEnvBool("SESSION", false)
DockerSwarm = common.GetEnvBool("SWARM", false)
DockerSystem = common.GetEnvBool("SYSTEM", false)
DockerTasks = common.GetEnvBool("TASKS", false)
DockerVersion = common.GetEnvBool("VERSION", true)
DockerVolumes = common.GetEnvBool("VOLUMES", false)
}

View file

@ -1,13 +1,13 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/watcher/health"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
@ -73,5 +73,7 @@ func CheckHealth(w http.ResponseWriter, r *http.Request) {
return
}
gphttp.RespondJSON(w, r, result)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(result)
}

View file

@ -1,38 +0,0 @@
package handler
import (
"net/http"
"net/url"
"github.com/docker/docker/client"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
)
func serviceUnavailable(w http.ResponseWriter, r *http.Request) {
http.Error(w, "docker socket is not available", http.StatusServiceUnavailable)
}
func mockDockerSocketHandler() http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("mock docker response"))
})
}
func DockerSocketHandler() http.HandlerFunc {
dockerClient, err := docker.NewClient(common.DockerHostFromEnv)
if err != nil {
logging.Warn().Err(err).Msg("failed to connect to docker client")
return serviceUnavailable
}
rp := reverseproxy.NewReverseProxy("docker", types.NewURL(&url.URL{
Scheme: "http",
Host: client.DummyHost,
}), dockerClient.HTTPClient().Transport)
return rp.ServeHTTP
}

View file

@ -2,201 +2,35 @@ package handler
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/pkg"
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
type ServeMux struct{ *http.ServeMux }
func (mux ServeMux) HandleMethods(methods, endpoint string, handler http.HandlerFunc) {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+agent.APIEndpointBase+endpoint, handler)
}
func (mux ServeMux) HandleEndpoint(method, endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(method+" "+agent.APIEndpointBase+endpoint, handler)
}
func (mux ServeMux) HandleFunc(endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(agent.APIEndpointBase+endpoint, handler)
}
type NopWriteCloser struct {
io.Writer
}
func (NopWriteCloser) Close() error {
return nil
}
func NewAgentHandler() http.Handler {
mux := ServeMux{http.NewServeMux()}
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
mux.HandleMethods("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
mux.HandleMethods("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
mux.HandleEndpoint("GET", agent.EndpointVersion, pkg.GetVersionHTTPHandler())
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.AgentName)
})
mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth)
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc())
mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", socketproxy.DockerSocketHandler())
return mux
}
func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "Endpoint not allowed", http.StatusForbidden)
}
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
func NewDockerHandler() http.Handler {
r := mux.NewRouter()
var socketHandler http.HandlerFunc
if common.IsTest {
socketHandler = mockDockerSocketHandler()
} else {
socketHandler = DockerSocketHandler()
}
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
allowedPaths := []string{}
deniedPaths := []string{}
if env.DockerContainers {
allowedPaths = append(allowedPaths, "/containers")
if !env.DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
deniedPaths = append(deniedPaths, containerPath+"/restart")
deniedPaths = append(deniedPaths, containerPath+"/kill")
}
if !env.DockerStart {
deniedPaths = append(deniedPaths, containerPath+"/start")
}
if !env.DockerStop && env.DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
}
}
if env.DockerAuth {
allowedPaths = append(allowedPaths, "/auth")
}
if env.DockerBuild {
allowedPaths = append(allowedPaths, "/build")
}
if env.DockerCommit {
allowedPaths = append(allowedPaths, "/commit")
}
if env.DockerConfigs {
allowedPaths = append(allowedPaths, "/configs")
}
if env.DockerDistribution {
allowedPaths = append(allowedPaths, "/distribution")
}
if env.DockerEvents {
allowedPaths = append(allowedPaths, "/events")
}
if env.DockerExec {
allowedPaths = append(allowedPaths, "/exec")
}
if env.DockerGrpc {
allowedPaths = append(allowedPaths, "/grpc")
}
if env.DockerImages {
allowedPaths = append(allowedPaths, "/images")
}
if env.DockerInfo {
allowedPaths = append(allowedPaths, "/info")
}
if env.DockerNetworks {
allowedPaths = append(allowedPaths, "/networks")
}
if env.DockerNodes {
allowedPaths = append(allowedPaths, "/nodes")
}
if env.DockerPing {
allowedPaths = append(allowedPaths, "/_ping")
}
if env.DockerPlugins {
allowedPaths = append(allowedPaths, "/plugins")
}
if env.DockerSecrets {
allowedPaths = append(allowedPaths, "/secrets")
}
if env.DockerServices {
allowedPaths = append(allowedPaths, "/services")
}
if env.DockerSession {
allowedPaths = append(allowedPaths, "/session")
}
if env.DockerSwarm {
allowedPaths = append(allowedPaths, "/swarm")
}
if env.DockerSystem {
allowedPaths = append(allowedPaths, "/system")
}
if env.DockerTasks {
allowedPaths = append(allowedPaths, "/tasks")
}
if env.DockerVersion {
allowedPaths = append(allowedPaths, "/version")
}
if env.DockerVolumes {
allowedPaths = append(allowedPaths, "/volumes")
}
// Helper to determine if a path should be treated as a prefix
isPrefixPath := func(path string) bool {
return strings.Count(path, "/") == 1
}
// 1. Register Denied Paths (specific)
for _, path := range deniedPaths {
// Handle with version prefix
r.HandleFunc(apiVersionPrefix+path, endpointNotAllowed)
// Handle without version prefix
r.HandleFunc(path, endpointNotAllowed)
}
// 2. Register Allowed Paths
for _, p := range allowedPaths {
fullPathWithVersion := apiVersionPrefix + p
if isPrefixPath(p) {
r.PathPrefix(fullPathWithVersion).Handler(socketHandler)
r.PathPrefix(p).Handler(socketHandler)
} else {
r.HandleFunc(fullPathWithVersion, socketHandler)
r.HandleFunc(p, socketHandler)
}
}
// 3. Add fallback for any other routes
r.PathPrefix("/").HandlerFunc(endpointNotAllowed)
// HTTP method filtering
if !env.DockerPost {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodPost, http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}

View file

@ -3,18 +3,26 @@ package handler
import (
"crypto/tls"
"net/http"
"net/url"
"net/http/httputil"
"strconv"
"time"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/agentproxy"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/reverseproxy"
"github.com/yusing/go-proxy/internal/net/types"
)
func NewTransport() *http.Transport {
return &http.Transport{
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 60 * time.Second,
WriteBufferSize: 16 * 1024, // 16KB
ReadBufferSize: 16 * 1024, // 16KB
}
}
func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
host := r.Header.Get(agentproxy.HeaderXProxyHost)
isHTTPS, _ := strconv.ParseBool(r.Header.Get(agentproxy.HeaderXProxyHTTPS))
@ -34,11 +42,9 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
scheme = "https"
}
var transport *http.Transport
transport := NewTransport()
if skipTLSVerify {
transport = gphttp.NewTransportWithTLSConfig(&tls.Config{InsecureSkipVerify: true})
} else {
transport = gphttp.NewTransport()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if responseHeaderTimeout > 0 {
@ -49,14 +55,13 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Host = ""
r.URL.Path = r.URL.Path[agent.HTTPProxyURLPrefixLen:] // strip the {API_BASE}/proxy/http prefix
r.RequestURI = r.URL.String()
r.URL.Host = host
rp := &httputil.ReverseProxy{
Director: func(r *http.Request) {
r.URL.Scheme = scheme
logging.Debug().Msgf("proxy http request: %s %s", r.Method, r.URL.String())
rp := reverseproxy.NewReverseProxy("agent", types.NewURL(&url.URL{
Scheme: scheme,
Host: host,
}), transport)
r.URL.Host = host
},
Transport: transport,
}
rp.ServeHTTP(w, r)
}

View file

@ -2,7 +2,7 @@
services:
socket-proxy:
container_name: socket-proxy
image: lscr.io/linuxserver/socket-proxy:latest
image: ghcr.io/yusing/socket-proxy:latest
environment:
- ALLOW_START=1
- ALLOW_STOP=1

35
go.mod
View file

@ -6,6 +6,12 @@ replace github.com/yusing/go-proxy/agent => ./agent
replace github.com/yusing/go-proxy/internal/dnsproviders => ./internal/dnsproviders
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250425105916-b2ad800de7a1
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250502022742-408a348f1b97
require (
github.com/PuerkitoBio/goquery v1.10.3 // parsing HTML for extract fav icon
github.com/coder/websocket v1.8.13 // websocket for API and agent
@ -24,16 +30,12 @@ require (
golang.org/x/crypto v0.38.0 // encrypting password with bcrypt
golang.org/x/net v0.40.0 // HTTP header utilities
golang.org/x/oauth2 v0.30.0 // oauth2 authentication
golang.org/x/text v0.25.0 // string utilities
golang.org/x/time v0.11.0 // time utilities
gopkg.in/yaml.v3 v3.0.1 // indirect; yaml parsing for different config files
)
replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2
require (
github.com/docker/cli v28.1.1+incompatible
github.com/goccy/go-yaml v1.17.1
github.com/goccy/go-yaml v1.17.1 // yaml parsing for different config files
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/luthermonson/go-proxmox v0.2.2
github.com/oschwald/maxminddb-golang v1.13.1
@ -41,21 +43,11 @@ require (
github.com/samber/slog-zerolog/v2 v2.7.3
github.com/spf13/afero v1.14.0
github.com/stretchr/testify v1.10.0
github.com/yusing/go-proxy/agent v0.0.0-20250508094936-75ee0e63bd7d
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-20250508094936-75ee0e63bd7d
github.com/yusing/go-proxy/agent v0.0.0-20250509063132-4ddfb48b9d0b
github.com/yusing/go-proxy/internal/dnsproviders v0.0.0-20250509063132-4ddfb48b9d0b
go.uber.org/atomic v1.11.0
)
require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
)
replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250425105916-b2ad800de7a1
replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250502022742-408a348f1b97
require (
cloud.google.com/go/auth v0.16.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
@ -130,6 +122,7 @@ require (
github.com/gophercloud/gophercloud v1.14.1 // indirect
github.com/gophercloud/utils v0.0.0-20231010081019-80377eca5d56 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.7 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
@ -209,7 +202,7 @@ require (
github.com/spf13/pflag v1.0.6 // indirect
github.com/spf13/viper v1.20.1 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
github.com/tklauser/go-sysconf v0.3.15 // indirect
@ -217,7 +210,7 @@ require (
github.com/transip/gotransip/v6 v6.26.0 // indirect
github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec // indirect
github.com/vinyldns/go-vinyldns v0.9.16 // indirect
github.com/volcengine/volc-sdk-golang v1.0.206 // indirect
github.com/volcengine/volc-sdk-golang v1.0.207 // indirect
github.com/vultr/govultr/v3 v3.20.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
@ -226,8 +219,10 @@ require (
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 // indirect
go.opentelemetry.io/otel/metric v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.5.2 // indirect
go.uber.org/multierr v1.11.0 // indirect
@ -235,6 +230,7 @@ require (
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.25.0
golang.org/x/tools v0.33.0 // indirect
google.golang.org/api v0.232.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250422160041-2d3770c4ea7f // indirect
@ -245,6 +241,7 @@ require (
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/ns1/ns1-go.v2 v2.14.3 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/api v0.33.0 // indirect
k8s.io/apimachinery v0.33.0 // indirect
k8s.io/klog/v2 v2.130.1 // indirect

12
go.sum
View file

@ -1615,8 +1615,8 @@ github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNG
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1136/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 h1:jKVzMJy52E0zGbabQiZ7KaaYJwwwWblZAKgkt0Mex5E=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 h1:S4dJSWhOtaPjp0/GO/yhzUC6DfZvpWhrnsEKaLxr73c=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 h1:kMIdSU5IvpOROh27ToVQ3hlm6ym3lCRs9tnGCOBoZqk=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136/go.mod h1:FpyIz3mymKaExVs6Fz27kxDBS42jqZn7vbACtxdeEH4=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
@ -1642,8 +1642,8 @@ github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8A
github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U=
github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ=
github.com/vinyldns/go-vinyldns v0.9.16/go.mod h1:5qIJOdmzAnatKjurI+Tl4uTus7GJKJxb+zitufjHs3Q=
github.com/volcengine/volc-sdk-golang v1.0.206 h1:7NG8FCpvu9wbx+Z4I/p3tcTS2zdBqTZtJXgydunGy6g=
github.com/volcengine/volc-sdk-golang v1.0.206/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/volcengine/volc-sdk-golang v1.0.207 h1:1OJ/nC92dF1URRoyO1AHSghCob12NT1PAA/GoK8uU18=
github.com/volcengine/volc-sdk-golang v1.0.207/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/vultr/govultr/v3 v3.20.0 h1:O+Om6gXpN6ehwAIIKq5DyGuekpyHaoRlwrxTb44bDzA=
github.com/vultr/govultr/v3 v3.20.0/go.mod h1:q34Wd76upKmf+vxFMgaNMH3A8BbsPBmSYZUGC8oZa5w=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
@ -1701,8 +1701,8 @@ go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0 h1:1fTNlAIJZGWLP5FVu0fikVry1IsiUnXjf7QFvoNN3Xw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.35.0/go.mod h1:zjPK58DtkqQFn+YUMbx0M2XV3QgKU0gS9LeGohREyK4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0 h1:xJ2qHD0C1BeYVTLLR9sX12+Qb95kfeD/byKj6Ky1pXg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.35.0/go.mod h1:u5BF1xyjstDowA1R5QAO9JHzqK+ublenEW/dyqTjBVk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.31.0 h1:lUsI2TYsQw2r1IASwoROaCnjdj2cvC2+Jbxvk6nHnWU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.31.0/go.mod h1:2HpZxxQurfGxJlJDblybejHB6RX6pmExPNe517hREw4=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=

View file

@ -12,6 +12,7 @@ import (
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
"github.com/yusing/go-proxy/pkg"
@ -45,7 +46,7 @@ func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...b
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
gpwebsocket.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}

View file

@ -3,8 +3,7 @@ package common
import (
"crypto/rand"
"encoding/base64"
"github.com/rs/zerolog/log"
"log"
)
func decodeJWTKey(key string) []byte {
@ -13,7 +12,7 @@ func decodeJWTKey(key string) []byte {
}
bytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
log.Fatal().Str("key", key).Err(err).Msg("failed to decode secret")
log.Fatalf("failed to decode secret: %s", err)
}
return bytes
}
@ -22,7 +21,7 @@ func RandomJWTKey() []byte {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
log.Fatal().Err(err).Msg("failed to generate random jwt key")
log.Fatalf("failed to generate random jwt key: %s", err)
}
return key
}

View file

@ -2,13 +2,13 @@ package common
import (
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
@ -78,7 +78,7 @@ func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T
if err == nil {
return parsed
}
log.Fatal().Err(err).Msgf("env %s: invalid %T value: %s", key, parsed, value)
log.Fatalf("env %s: invalid %T value: %s", key, parsed, value)
return defaultValue
}
@ -105,7 +105,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host string, portInt in
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Fatal().Msgf("env %s: invalid address: %s", key, addr)
log.Fatalf("env %s: invalid address: %s", key, addr)
}
if host == "" {
host = "localhost"
@ -113,7 +113,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host string, portInt in
fullURL = fmt.Sprintf("%s://%s:%s", scheme, host, port)
portInt, err = strconv.Atoi(port)
if err != nil {
log.Fatal().Msgf("env %s: invalid port: %s", key, port)
log.Fatalf("env %s: invalid port: %s", key, port)
}
return
}

View file

@ -40,7 +40,7 @@ func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PE
var agentCfg agent.AgentConfig
agentCfg.Addr = host
err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key)
err := agentCfg.StartWithCerts(cfg.Task().Context(), ca.Cert, client.Cert, client.Key)
if err != nil {
return 0, gperr.Wrap(err, "failed to start agent")
}

View file

@ -328,8 +328,8 @@ func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
removeAllAgents()
for _, agent := range providers.Agents {
if err := agent.Start(cfg.task); err != nil {
errs.Add(err.Subject(agent.String()))
if err := agent.Start(cfg.task.Context()); err != nil {
errs.Add(gperr.PrependSubject(agent.String(), err))
continue
}
addAgent(agent)

View file

@ -6,7 +6,7 @@ replace github.com/yusing/go-proxy => ../..
require (
github.com/go-acme/lego/v4 v4.23.1
github.com/yusing/go-proxy v0.12.3
github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000
)
require (
@ -146,13 +146,13 @@ require (
github.com/spf13/viper v1.20.1 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
github.com/transip/gotransip/v6 v6.26.0 // indirect
github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec // indirect
github.com/vinyldns/go-vinyldns v0.9.16 // indirect
github.com/volcengine/volc-sdk-golang v1.0.206 // indirect
github.com/volcengine/volc-sdk-golang v1.0.207 // indirect
github.com/vultr/govultr/v3 v3.20.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect

View file

@ -1519,8 +1519,8 @@ github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNG
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1136/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160 h1:jKVzMJy52E0zGbabQiZ7KaaYJwwwWblZAKgkt0Mex5E=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1160/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161 h1:S4dJSWhOtaPjp0/GO/yhzUC6DfZvpWhrnsEKaLxr73c=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1161/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136 h1:kMIdSU5IvpOROh27ToVQ3hlm6ym3lCRs9tnGCOBoZqk=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.1136/go.mod h1:FpyIz3mymKaExVs6Fz27kxDBS42jqZn7vbACtxdeEH4=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
@ -1538,8 +1538,8 @@ github.com/ultradns/ultradns-go-sdk v1.8.0-20241010134910-243eeec/go.mod h1:BZr7
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/vinyldns/go-vinyldns v0.9.16 h1:GZJStDkcCk1F1AcRc64LuuMh+ENL8pHA0CVd4ulRMcQ=
github.com/vinyldns/go-vinyldns v0.9.16/go.mod h1:5qIJOdmzAnatKjurI+Tl4uTus7GJKJxb+zitufjHs3Q=
github.com/volcengine/volc-sdk-golang v1.0.206 h1:7NG8FCpvu9wbx+Z4I/p3tcTS2zdBqTZtJXgydunGy6g=
github.com/volcengine/volc-sdk-golang v1.0.206/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/volcengine/volc-sdk-golang v1.0.207 h1:1OJ/nC92dF1URRoyO1AHSghCob12NT1PAA/GoK8uU18=
github.com/volcengine/volc-sdk-golang v1.0.207/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
github.com/vultr/govultr/v3 v3.20.0 h1:O+Om6gXpN6ehwAIIKq5DyGuekpyHaoRlwrxTb44bDzA=
github.com/vultr/govultr/v3 v3.20.0/go.mod h1:q34Wd76upKmf+vxFMgaNMH3A8BbsPBmSYZUGC8oZa5w=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=

View file

@ -3,8 +3,8 @@ package accesslog_test
import (
"testing"
"github.com/yusing/go-proxy/internal/docker"
. "github.com/yusing/go-proxy/internal/logging/accesslog"
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/utils"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)

View file

@ -9,9 +9,8 @@ import (
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type logEntryRange struct {
@ -22,8 +21,8 @@ type memLogger struct {
*bytes.Buffer
sync.RWMutex
notifyLock sync.RWMutex
connChans F.Map[chan *logEntryRange, struct{}]
listeners F.Map[chan []byte, struct{}]
connChans *xsync.MapOf[chan *logEntryRange, struct{}]
listeners *xsync.MapOf[chan []byte, struct{}]
}
type MemLogger io.Writer
@ -40,8 +39,8 @@ const (
var memLoggerInstance = &memLogger{
Buffer: bytes.NewBuffer(make([]byte, maxMemLogSize)),
connChans: F.NewMapOf[chan *logEntryRange, struct{}](),
listeners: F.NewMapOf[chan []byte, struct{}](),
connChans: xsync.NewMapOf[chan *logEntryRange, struct{}](),
listeners: xsync.NewMapOf[chan []byte, struct{}](),
}
func GetMemLogger() MemLogger {
@ -136,7 +135,7 @@ func (m *memLogger) Write(p []byte) (n int, err error) {
func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -153,7 +152,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}()
if err := m.wsInitial(r.Context(), conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

View file

@ -6,11 +6,7 @@ import (
"time"
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
)
func warnNoMatchDomains() {
@ -19,13 +15,25 @@ func warnNoMatchDomains() {
var warnNoMatchDomainOnce sync.Once
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug {
allowedDomains := WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
} else {
@ -47,14 +55,14 @@ func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := Initiate(w, r)
if err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@ -67,7 +75,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
return
case <-ticker.C:
if err := do(conn); err != nil {
gphttp.ServerError(w, r, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
@ -79,7 +87,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do
// It logs an error if the message is not written successfully.
func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool {
if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil {
gperr.LogError("failed to write text message", err)
logging.Err(err).Msg("failed to write text message")
return false
}
return true

View file

@ -4,18 +4,6 @@ import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View file

@ -4,12 +4,21 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
. "github.com/yusing/go-proxy/internal/utils/strutils"
expect "github.com/yusing/go-proxy/internal/utils/testing"
)
func mustParseTime(t *testing.T, layout, value string) time.Time {
t.Helper()
time, err := time.Parse(layout, value)
if err != nil {
t.Fatalf("failed to parse time: %s", err)
}
return time
}
func TestFormatTime(t *testing.T) {
now := expect.Must(time.Parse(time.RFC3339, "2021-06-15T12:30:30Z"))
now := mustParseTime(t, time.RFC3339, "2021-06-15T12:30:30Z")
tests := []struct {
name string
@ -84,9 +93,9 @@ func TestFormatTime(t *testing.T) {
result := FormatTimeWithReference(tt.time, now)
if tt.expectedLength > 0 {
expect.Equal(t, len(result), tt.expectedLength, result)
require.Equal(t, tt.expectedLength, len(result), result)
} else {
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
}
})
}
@ -173,7 +182,7 @@ func TestFormatDuration(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatDuration(tt.duration)
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
})
}
}
@ -203,7 +212,7 @@ func TestFormatLastSeen(t *testing.T) {
result := FormatLastSeen(tt.time)
if tt.name == "zero time" {
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
} else {
// Just make sure it's not "never", the actual formatting is tested in TestFormatTime
if result == "never" {
@ -290,7 +299,7 @@ func TestFormatByteSize(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatByteSize(tt.size)
expect.Equal(t, result, tt.expected)
require.Equal(t, tt.expected, result)
})
}
}

View file

@ -4,8 +4,8 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/require"
. "github.com/yusing/go-proxy/internal/utils/strutils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
var alphaNumeric = func() string {
@ -31,8 +31,8 @@ func TestSplit(t *testing.T) {
for sep, rsep := range tests {
t.Run(sep, func(t *testing.T) {
expected := strings.Split(alphaNumeric, sep)
ExpectEqual(t, SplitRune(alphaNumeric, rsep), expected)
ExpectEqual(t, JoinRune(expected, rsep), alphaNumeric)
require.Equal(t, expected, SplitRune(alphaNumeric, rsep))
require.Equal(t, alphaNumeric, JoinRune(expected, rsep))
})
}
}

50
socket-proxy.Dockerfile Normal file
View file

@ -0,0 +1,50 @@
# Stage 1: deps
FROM golang:1.24.3-alpine AS deps
HEALTHCHECK NONE
# package version does not matter
# trunk-ignore(hadolint/DL3018)
RUN apk add --no-cache tzdata make libcap-setcap
ENV GOPATH=/root/go
WORKDIR /src
COPY socket-proxy/go.mod socket-proxy/go.sum ./
RUN go mod download -x
# Stage 2: builder
FROM deps AS builder
WORKDIR /src
COPY Makefile ./
COPY socket-proxy ./socket-proxy
ARG VERSION
ENV VERSION=${VERSION}
ARG MAKE_ARGS
ENV MAKE_ARGS=${MAKE_ARGS}
ENV GOCACHE=/root/.cache/go-build
ENV GOPATH=/root/go
RUN make ${MAKE_ARGS} docker=1 build
# Stage 3: Final image
FROM scratch
LABEL maintainer="yusing@6uo.me"
LABEL proxy.exclude=1
# copy timezone data
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
# copy binary
COPY --from=builder /app/run /app/run
WORKDIR /app
CMD ["/app/run"]

16
socket-proxy/cmd/main.go Normal file
View file

@ -0,0 +1,16 @@
package main
import (
"log"
"net/http"
socketproxy "github.com/yusing/go-proxy/socketproxy/pkg"
)
func main() {
if socketproxy.ListenAddr == "" {
log.Fatal("Docker socket address is not set")
}
log.Printf("Docker socket listening on: %s", socketproxy.ListenAddr)
http.ListenAndServe(socketproxy.ListenAddr, socketproxy.NewHandler())
}

5
socket-proxy/go.mod Normal file
View file

@ -0,0 +1,5 @@
module github.com/yusing/go-proxy/socketproxy
go 1.24.3
require github.com/gorilla/mux v1.8.1

2
socket-proxy/go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=

101
socket-proxy/pkg/env.go Normal file
View file

@ -0,0 +1,101 @@
package socketproxy
import (
"log"
"os"
"strconv"
)
var (
DockerSocket,
ListenAddr string
DockerPost,
DockerRestarts,
DockerStart,
DockerStop,
DockerAuth,
DockerBuild,
DockerCommit,
DockerConfigs,
DockerContainers,
DockerDistribution,
DockerEvents,
DockerExec,
DockerGrpc,
DockerImages,
DockerInfo,
DockerNetworks,
DockerNodes,
DockerPing,
DockerPlugins,
DockerSecrets,
DockerServices,
DockerSession,
DockerSwarm,
DockerSystem,
DockerTasks,
DockerVersion,
DockerVolumes bool
)
func init() {
Load()
}
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
value, ok := os.LookupEnv(key)
if !ok || value == "" {
return defaultValue
}
parsed, err := parser(value)
if err != nil {
log.Fatalf("env %s: invalid %T value: %s", key, parsed, value)
}
return parsed
}
func GetEnvString(key string, defaultValue string) string {
return GetEnv(key, defaultValue, stringstring)
}
func GetEnvBool(key string, defaultValue bool) bool {
return GetEnv(key, defaultValue, strconv.ParseBool)
}
func stringstring(s string) (string, error) {
return s, nil
}
func Load() {
DockerSocket = GetEnvString("DOCKER_SOCKET", GetEnvString("DOCKER_HOST", "/var/run/docker.sock"))
ListenAddr = GetEnvString("LISTEN_ADDR", GetEnvString("DOCKER_SOCKET_ADDR", "")) // default to disabled
DockerPost = GetEnvBool("POST", false)
DockerRestarts = GetEnvBool("ALLOW_RESTARTS", false)
DockerStart = GetEnvBool("ALLOW_START", false)
DockerStop = GetEnvBool("ALLOW_STOP", false)
DockerAuth = GetEnvBool("AUTH", false)
DockerBuild = GetEnvBool("BUILD", false)
DockerCommit = GetEnvBool("COMMIT", false)
DockerConfigs = GetEnvBool("CONFIGS", false)
DockerContainers = GetEnvBool("CONTAINERS", false)
DockerDistribution = GetEnvBool("DISTRIBUTION", false)
DockerEvents = GetEnvBool("EVENTS", true)
DockerExec = GetEnvBool("EXEC", false)
DockerGrpc = GetEnvBool("GRPC", false)
DockerImages = GetEnvBool("IMAGES", false)
DockerInfo = GetEnvBool("INFO", false)
DockerNetworks = GetEnvBool("NETWORKS", false)
DockerNodes = GetEnvBool("NODES", false)
DockerPing = GetEnvBool("PING", true)
DockerPlugins = GetEnvBool("PLUGINS", false)
DockerSecrets = GetEnvBool("SECRETS", false)
DockerServices = GetEnvBool("SERVICES", false)
DockerSession = GetEnvBool("SESSION", false)
DockerSwarm = GetEnvBool("SWARM", false)
DockerSystem = GetEnvBool("SYSTEM", false)
DockerTasks = GetEnvBool("TASKS", false)
DockerVersion = GetEnvBool("VERSION", true)
DockerVolumes = GetEnvBool("VOLUMES", false)
}

179
socket-proxy/pkg/handler.go Normal file
View file

@ -0,0 +1,179 @@
package socketproxy
import (
"context"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/gorilla/mux"
"net/url"
)
var dialer = &net.Dialer{KeepAlive: 1 * time.Second}
func dialDockerSocket(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", DockerSocket)
}
var DockerSocketHandler = dockerSocketHandler
func dockerSocketHandler() http.HandlerFunc {
rp := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: "api.moby.localhost",
})
rp.Transport = &http.Transport{
DialContext: dialDockerSocket,
}
return rp.ServeHTTP
}
func endpointNotAllowed(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "Endpoint not allowed", http.StatusForbidden)
}
// ref: https://github.com/Tecnativa/docker-socket-proxy/blob/master/haproxy.cfg
func NewHandler() http.Handler {
r := mux.NewRouter()
socketHandler := DockerSocketHandler()
const apiVersionPrefix = `/{version:(?:v[\d\.]+)?}`
const containerPath = "/containers/{id:[a-zA-Z0-9_.-]+}"
allowedPaths := []string{}
deniedPaths := []string{}
if DockerContainers {
allowedPaths = append(allowedPaths, "/containers")
if !DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
deniedPaths = append(deniedPaths, containerPath+"/restart")
deniedPaths = append(deniedPaths, containerPath+"/kill")
}
if !DockerStart {
deniedPaths = append(deniedPaths, containerPath+"/start")
}
if !DockerStop && DockerRestarts {
deniedPaths = append(deniedPaths, containerPath+"/stop")
}
}
if DockerAuth {
allowedPaths = append(allowedPaths, "/auth")
}
if DockerBuild {
allowedPaths = append(allowedPaths, "/build")
}
if DockerCommit {
allowedPaths = append(allowedPaths, "/commit")
}
if DockerConfigs {
allowedPaths = append(allowedPaths, "/configs")
}
if DockerDistribution {
allowedPaths = append(allowedPaths, "/distribution")
}
if DockerEvents {
allowedPaths = append(allowedPaths, "/events")
}
if DockerExec {
allowedPaths = append(allowedPaths, "/exec")
}
if DockerGrpc {
allowedPaths = append(allowedPaths, "/grpc")
}
if DockerImages {
allowedPaths = append(allowedPaths, "/images")
}
if DockerInfo {
allowedPaths = append(allowedPaths, "/info")
}
if DockerNetworks {
allowedPaths = append(allowedPaths, "/networks")
}
if DockerNodes {
allowedPaths = append(allowedPaths, "/nodes")
}
if DockerPing {
allowedPaths = append(allowedPaths, "/_ping")
}
if DockerPlugins {
allowedPaths = append(allowedPaths, "/plugins")
}
if DockerSecrets {
allowedPaths = append(allowedPaths, "/secrets")
}
if DockerServices {
allowedPaths = append(allowedPaths, "/services")
}
if DockerSession {
allowedPaths = append(allowedPaths, "/session")
}
if DockerSwarm {
allowedPaths = append(allowedPaths, "/swarm")
}
if DockerSystem {
allowedPaths = append(allowedPaths, "/system")
}
if DockerTasks {
allowedPaths = append(allowedPaths, "/tasks")
}
if DockerVersion {
allowedPaths = append(allowedPaths, "/version")
}
if DockerVolumes {
allowedPaths = append(allowedPaths, "/volumes")
}
// Helper to determine if a path should be treated as a prefix
isPrefixPath := func(path string) bool {
return strings.Count(path, "/") == 1
}
// 1. Register Denied Paths (specific)
for _, path := range deniedPaths {
// Handle with version prefix
r.HandleFunc(apiVersionPrefix+path, endpointNotAllowed)
// Handle without version prefix
r.HandleFunc(path, endpointNotAllowed)
}
// 2. Register Allowed Paths
for _, p := range allowedPaths {
fullPathWithVersion := apiVersionPrefix + p
if isPrefixPath(p) {
r.PathPrefix(fullPathWithVersion).Handler(socketHandler)
r.PathPrefix(p).Handler(socketHandler)
} else {
r.HandleFunc(fullPathWithVersion, socketHandler)
r.HandleFunc(p, socketHandler)
}
}
// 3. Add fallback for any other routes
r.PathPrefix("/").HandlerFunc(endpointNotAllowed)
// HTTP method filtering
if !DockerPost {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case http.MethodPost, http.MethodGet:
r.ServeHTTP(w, req)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}

View file

@ -1,13 +1,26 @@
package handler
package socketproxy_test
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/yusing/go-proxy/agent/pkg/env"
. "github.com/yusing/go-proxy/socketproxy/pkg"
)
func mockDockerSocketHandler() http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("mock docker response"))
})
}
func TestMain(m *testing.M) {
DockerSocketHandler = mockDockerSocketHandler
os.Exit(m.Run())
}
func TestNewDockerHandler(t *testing.T) {
tests := []struct {
name string
@ -35,7 +48,7 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodGet,
path: "/containers",
envSetup: func() {
env.DockerContainers = true
DockerContainers = true
},
wantStatusCode: http.StatusOK,
},
@ -44,7 +57,7 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodGet,
path: "/containers",
envSetup: func() {
env.DockerContainers = false
DockerContainers = false
},
wantStatusCode: http.StatusForbidden,
},
@ -53,7 +66,7 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/_ping",
envSetup: func() {
env.DockerPost = false
DockerPost = false
},
wantStatusCode: http.StatusMethodNotAllowed,
},
@ -62,8 +75,8 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/_ping",
envSetup: func() {
env.DockerPost = true
env.DockerPing = true
DockerPost = true
DockerPing = true
},
wantStatusCode: http.StatusOK,
},
@ -72,9 +85,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/restart",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerRestarts = false
DockerPost = true
DockerContainers = true
DockerRestarts = false
},
wantStatusCode: http.StatusForbidden,
},
@ -83,9 +96,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/restart",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerRestarts = true
DockerPost = true
DockerContainers = true
DockerRestarts = true
},
wantStatusCode: http.StatusOK,
},
@ -94,9 +107,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/start",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStart = false
DockerPost = true
DockerContainers = true
DockerStart = false
},
wantStatusCode: http.StatusForbidden,
},
@ -105,9 +118,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/start",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStart = true
DockerPost = true
DockerContainers = true
DockerStart = true
},
wantStatusCode: http.StatusOK,
},
@ -116,9 +129,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/stop",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStop = false
DockerPost = true
DockerContainers = true
DockerStop = false
},
wantStatusCode: http.StatusForbidden,
},
@ -127,9 +140,9 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPost,
path: "/containers/test-container/stop",
envSetup: func() {
env.DockerPost = true
env.DockerContainers = true
env.DockerStop = true
DockerPost = true
DockerContainers = true
DockerStop = true
},
wantStatusCode: http.StatusOK,
},
@ -138,7 +151,7 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodGet,
path: "/v1.41/version",
envSetup: func() {
env.DockerVersion = true
DockerVersion = true
},
wantStatusCode: http.StatusOK,
},
@ -147,7 +160,7 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodPut,
path: "/version",
envSetup: func() {
env.DockerVersion = true
DockerVersion = true
},
wantStatusCode: http.StatusMethodNotAllowed,
},
@ -156,30 +169,30 @@ func TestNewDockerHandler(t *testing.T) {
method: http.MethodDelete,
path: "/version",
envSetup: func() {
env.DockerVersion = true
DockerVersion = true
},
wantStatusCode: http.StatusMethodNotAllowed,
},
}
// Save original env values to restore after tests
originalContainers := env.DockerContainers
originalRestarts := env.DockerRestarts
originalStart := env.DockerStart
originalStop := env.DockerStop
originalPost := env.DockerPost
originalPing := env.DockerPing
originalVersion := env.DockerVersion
originalContainers := DockerContainers
originalRestarts := DockerRestarts
originalStart := DockerStart
originalStop := DockerStop
originalPost := DockerPost
originalPing := DockerPing
originalVersion := DockerVersion
defer func() {
// Restore original values
env.DockerContainers = originalContainers
env.DockerRestarts = originalRestarts
env.DockerStart = originalStart
env.DockerStop = originalStop
env.DockerPost = originalPost
env.DockerPing = originalPing
env.DockerVersion = originalVersion
DockerContainers = originalContainers
DockerRestarts = originalRestarts
DockerStart = originalStart
DockerStop = originalStop
DockerPost = originalPost
DockerPing = originalPing
DockerVersion = originalVersion
}()
for _, tt := range tests {
@ -188,7 +201,7 @@ func TestNewDockerHandler(t *testing.T) {
tt.envSetup()
// Create test handler that will record the response for verification
dockerHandler := NewDockerHandler()
dockerHandler := NewHandler()
// Test server to capture the response
recorder := httptest.NewRecorder()
@ -291,73 +304,73 @@ func TestNewDockerHandler_PathHandling(t *testing.T) {
defer func() {
// Restore original env values
env.Load()
Load()
}()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset all Docker* env vars to false for this test
env.Load()
Load()
// Enable POST for all these tests
env.DockerPost = true
DockerPost = true
// Set the specific env var for this test
switch tt.envVarName {
case "DockerContainers":
env.DockerContainers = tt.envVarValue
DockerContainers = tt.envVarValue
case "DockerRestarts":
env.DockerRestarts = tt.envVarValue
DockerRestarts = tt.envVarValue
case "DockerStart":
env.DockerStart = tt.envVarValue
DockerStart = tt.envVarValue
case "DockerStop":
env.DockerStop = tt.envVarValue
DockerStop = tt.envVarValue
case "DockerAuth":
env.DockerAuth = tt.envVarValue
DockerAuth = tt.envVarValue
case "DockerBuild":
env.DockerBuild = tt.envVarValue
DockerBuild = tt.envVarValue
case "DockerCommit":
env.DockerCommit = tt.envVarValue
DockerCommit = tt.envVarValue
case "DockerConfigs":
env.DockerConfigs = tt.envVarValue
DockerConfigs = tt.envVarValue
case "DockerDistribution":
env.DockerDistribution = tt.envVarValue
DockerDistribution = tt.envVarValue
case "DockerEvents":
env.DockerEvents = tt.envVarValue
DockerEvents = tt.envVarValue
case "DockerExec":
env.DockerExec = tt.envVarValue
DockerExec = tt.envVarValue
case "DockerGrpc":
env.DockerGrpc = tt.envVarValue
DockerGrpc = tt.envVarValue
case "DockerImages":
env.DockerImages = tt.envVarValue
DockerImages = tt.envVarValue
case "DockerInfo":
env.DockerInfo = tt.envVarValue
DockerInfo = tt.envVarValue
case "DockerNetworks":
env.DockerNetworks = tt.envVarValue
DockerNetworks = tt.envVarValue
case "DockerNodes":
env.DockerNodes = tt.envVarValue
DockerNodes = tt.envVarValue
case "DockerPlugins":
env.DockerPlugins = tt.envVarValue
DockerPlugins = tt.envVarValue
case "DockerSecrets":
env.DockerSecrets = tt.envVarValue
DockerSecrets = tt.envVarValue
case "DockerServices":
env.DockerServices = tt.envVarValue
DockerServices = tt.envVarValue
case "DockerSession":
env.DockerSession = tt.envVarValue
DockerSession = tt.envVarValue
case "DockerSwarm":
env.DockerSwarm = tt.envVarValue
DockerSwarm = tt.envVarValue
case "DockerSystem":
env.DockerSystem = tt.envVarValue
DockerSystem = tt.envVarValue
case "DockerTasks":
env.DockerTasks = tt.envVarValue
DockerTasks = tt.envVarValue
case "DockerVolumes":
env.DockerVolumes = tt.envVarValue
DockerVolumes = tt.envVarValue
default:
t.Fatalf("Unknown env var: %s", tt.envVarName)
}
// Create test handler
dockerHandler := NewDockerHandler()
dockerHandler := NewHandler()
// Test server to capture the response
recorder := httptest.NewRecorder()
@ -385,11 +398,11 @@ func TestNewDockerHandler_PathHandling(t *testing.T) {
// This is a more comprehensive test that verifies the full request/response chain
func TestNewDockerHandlerWithMockDocker(t *testing.T) {
// Set up environment
env.DockerContainers = true
env.DockerPost = true
DockerContainers = true
DockerPost = true
// Create the handler
handler := NewDockerHandler()
handler := NewHandler()
// Test a valid request
req, _ := http.NewRequest(http.MethodGet, "/containers", nil)
@ -401,8 +414,8 @@ func TestNewDockerHandlerWithMockDocker(t *testing.T) {
}
// Test a disallowed path
env.DockerContainers = false
handler = NewDockerHandler() // recreate with new env
DockerContainers = false
handler = NewHandler() // recreate with new env
req, _ = http.NewRequest(http.MethodGet, "/containers", nil)
recorder = httptest.NewRecorder()