diff --git a/agent/go.mod b/agent/go.mod index 48ca2b1..f225e10 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -11,7 +11,7 @@ replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250425 replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250502022742-408a348f1b97 require ( - github.com/coder/websocket v1.8.13 + github.com/gorilla/websocket v1.5.3 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.10.0 github.com/yusing/go-proxy v0.0.0-00010101000000-000000000000 @@ -45,7 +45,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect github.com/gorilla/mux v1.8.1 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/gotify/server/v2 v2.6.3 // indirect github.com/jinzhu/copier v0.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/agent/go.sum b/agent/go.sum index 233cb7e..4d0713c 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -10,8 +10,6 @@ github.com/buger/goterm v1.0.4 h1:Z9YvGmOih81P0FbVtEYTFF6YsSgxSUKEhf/f9bTMXbY= github.com/buger/goterm v1.0.4/go.mod h1:HiFWV3xnkolgrBV3mY8m0X0Pumt4zg4QhbdOzQtB8tE= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= -github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/agent/pkg/agent/requests.go b/agent/pkg/agent/requests.go index b9b810a..6dc655d 100644 --- a/agent/pkg/agent/requests.go +++ b/agent/pkg/agent/requests.go @@ -5,7 +5,7 @@ import ( "io" "net/http" - "github.com/coder/websocket" + "github.com/gorilla/websocket" ) func (cfg *AgentConfig) Do(ctx context.Context, method, endpoint string, body io.Reader) (*http.Response, error) { @@ -42,8 +42,12 @@ func (cfg *AgentConfig) Fetch(ctx context.Context, endpoint string) ([]byte, int } func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) { - return websocket.Dial(ctx, APIBaseURL+endpoint, &websocket.DialOptions{ - HTTPClient: cfg.NewHTTPClient(), - Host: AgentHost, + transport := cfg.Transport() + dialer := websocket.Dialer{ + NetDialContext: transport.DialContext, + NetDialTLSContext: transport.DialTLSContext, + } + return dialer.DialContext(ctx, APIBaseURL+endpoint, http.Header{ + "Host": {AgentHost}, }) } diff --git a/go.mod b/go.mod index 60d723e..a75c48c 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,8 @@ replace github.com/docker/docker => github.com/godoxy-app/docker v0.0.0-20250425 replace github.com/shirou/gopsutil/v4 => github.com/godoxy-app/gopsutil/v4 v4.0.0-20250502022742-408a348f1b97 require ( + github.com/gorilla/websocket v1.5.3 // websocket for API and agent github.com/PuerkitoBio/goquery v1.10.3 // parsing HTML for extract fav icon - github.com/coder/websocket v1.8.13 // websocket for API and agent github.com/coreos/go-oidc/v3 v3.14.1 // oidc authentication github.com/docker/docker v28.1.1+incompatible // docker daemon github.com/fsnotify/fsnotify v1.9.0 // file watcher @@ -121,7 +121,6 @@ require ( github.com/googleapis/gax-go/v2 v2.14.2 // indirect github.com/gophercloud/gophercloud v1.14.1 // indirect github.com/gophercloud/utils v0.0.0-20231010081019-80377eca5d56 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect diff --git a/go.sum b/go.sum index d638169..6a85653 100644 --- a/go.sum +++ b/go.sum @@ -768,8 +768,6 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20230310173818-32f1caf87195/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= -github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/internal/api/v1/agents.go b/internal/api/v1/agents.go index d9d2a87..0579247 100644 --- a/internal/api/v1/agents.go +++ b/internal/api/v1/agents.go @@ -4,8 +4,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" @@ -15,8 +14,7 @@ import ( func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { if httpheaders.IsWebsocket(r.Header) { gpwebsocket.Periodic(w, r, 10*time.Second, func(conn *websocket.Conn) error { - wsjson.Write(r.Context(), conn, cfg.ListAgents()) - return nil + return conn.WriteJSON(cfg.ListAgents()) }) } else { gphttp.RespondJSON(w, r, cfg.ListAgents()) diff --git a/internal/api/v1/certapi/renew.go b/internal/api/v1/certapi/renew.go index bb993f9..b274ef4 100644 --- a/internal/api/v1/certapi/renew.go +++ b/internal/api/v1/certapi/renew.go @@ -22,8 +22,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - //nolint:errcheck - defer conn.CloseNow() + defer conn.Close() logs, cancel := memlogger.Events() defer cancel() @@ -35,7 +34,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) { err = autocert.ObtainCert() if err != nil { gperr.LogError("failed to obtain cert", err) - gpwebsocket.WriteText(r, conn, err.Error()) + _ = gpwebsocket.WriteText(conn, err.Error()) } else { logging.Info().Msg("cert obtained successfully") } @@ -46,7 +45,7 @@ func RenewCert(w http.ResponseWriter, r *http.Request) { if err != nil { return } - if !gpwebsocket.WriteText(r, conn, string(l)) { + if err := gpwebsocket.WriteText(conn, string(l)); err != nil { return } case <-done: diff --git a/internal/api/v1/dockerapi/logs.go b/internal/api/v1/dockerapi/logs.go index 6f2a6d3..385b59e 100644 --- a/internal/api/v1/dockerapi/logs.go +++ b/internal/api/v1/dockerapi/logs.go @@ -1,15 +1,18 @@ package dockerapi import ( + "context" + "errors" "net/http" "strconv" - "github.com/coder/websocket" "github.com/docker/docker/api/types/container" "github.com/docker/docker/pkg/stdcopy" + "github.com/gorilla/websocket" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" + "github.com/yusing/go-proxy/internal/task" ) func Logs(w http.ResponseWriter, r *http.Request) { @@ -31,6 +34,7 @@ func Logs(w http.ResponseWriter, r *http.Request) { gphttp.NotFound(w, "server not found") return } + defer dockerClient.Close() opts := container.LogsOptions{ ShowStdout: stdout, @@ -56,11 +60,14 @@ func Logs(w http.ResponseWriter, r *http.Request) { if err != nil { return } - defer conn.CloseNow() + defer conn.Close() - writer := gpwebsocket.NewWriter(r.Context(), conn, websocket.MessageText) + writer := gpwebsocket.NewWriter(r.Context(), conn, websocket.TextMessage) _, err = stdcopy.StdCopy(writer, writer, logs) // de-multiplex logs if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, task.ErrProgramExiting) { + return + } logging.Err(err). Str("server", server). Str("container", containerID). diff --git a/internal/api/v1/dockerapi/utils.go b/internal/api/v1/dockerapi/utils.go index 03f2484..6cfd326 100644 --- a/internal/api/v1/dockerapi/utils.go +++ b/internal/api/v1/dockerapi/utils.go @@ -6,8 +6,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/gperr" @@ -65,10 +64,12 @@ func getDockerClient(server string) (*docker.SharedClient, bool, error) { break } } - for _, agent := range cfg.ListAgents() { - if agent.Name() == server { - host = agent.FakeDockerHost() - break + if host == "" { + for _, agent := range cfg.ListAgents() { + if agent.Name() == server { + host = agent.FakeDockerHost() + break + } } } if host == "" { @@ -115,7 +116,7 @@ func serveHTTP[V any, T ResultType[V]](w http.ResponseWriter, r *http.Request, g if err != nil { return err } - return wsjson.Write(r.Context(), conn, result) + return conn.WriteJSON(result) }) } else { result, err := getResult(r.Context(), dockerClients) diff --git a/internal/api/v1/health.go b/internal/api/v1/health.go index 2528c3f..ef10fe1 100644 --- a/internal/api/v1/health.go +++ b/internal/api/v1/health.go @@ -4,8 +4,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" @@ -15,7 +14,7 @@ import ( func Health(w http.ResponseWriter, r *http.Request) { if httpheaders.IsWebsocket(r.Header) { gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, routes.HealthMap()) + return conn.WriteJSON(routes.HealthMap()) }) } else { gphttp.RespondJSON(w, r, routes.HealthMap()) diff --git a/internal/api/v1/list_route_providers.go b/internal/api/v1/list_route_providers.go index 2b373e9..341c6cd 100644 --- a/internal/api/v1/list_route_providers.go +++ b/internal/api/v1/list_route_providers.go @@ -4,8 +4,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" @@ -15,7 +14,7 @@ import ( func ListRouteProvidersHandler(cfgInstance config.ConfigInstance, w http.ResponseWriter, r *http.Request) { if httpheaders.IsWebsocket(r.Header) { gpwebsocket.Periodic(w, r, 5*time.Second, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, cfgInstance.RouteProviderList()) + return conn.WriteJSON(cfgInstance.RouteProviderList()) }) } else { gphttp.RespondJSON(w, r, cfgInstance.RouteProviderList()) diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 1fbac51..0539c31 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -4,8 +4,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" @@ -16,7 +15,7 @@ import ( func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { if httpheaders.IsWebsocket(r.Header) { gpwebsocket.Periodic(w, r, 1*time.Second, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, getStats(cfg)) + return conn.WriteJSON(getStats(cfg)) }) } else { gphttp.RespondJSON(w, r, getStats(cfg)) diff --git a/internal/logging/memlogger/mem_logger.go b/internal/logging/memlogger/mem_logger.go index 60e6cd3..a6645f9 100644 --- a/internal/logging/memlogger/mem_logger.go +++ b/internal/logging/memlogger/mem_logger.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/coder/websocket" + "github.com/gorilla/websocket" "github.com/puzpuzpuz/xsync/v4" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" ) @@ -81,7 +81,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.connChans.Store(logCh, struct{}{}) defer func() { - _ = conn.CloseNow() + _ = conn.Close() m.notifyLock.Lock() m.connChans.Delete(logCh) @@ -89,7 +89,7 @@ func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.notifyLock.Unlock() }() - if err := m.wsInitial(r.Context(), conn); err != nil { + if err := m.wsInitial(conn); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -169,15 +169,16 @@ func (m *memLogger) events() (logs <-chan []byte, cancel func()) { } } -func (m *memLogger) writeBytes(ctx context.Context, conn *websocket.Conn, b []byte) error { - return conn.Write(ctx, websocket.MessageText, b) +func (m *memLogger) writeBytes(conn *websocket.Conn, b []byte) error { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return conn.WriteMessage(websocket.TextMessage, b) } -func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error { +func (m *memLogger) wsInitial(conn *websocket.Conn) error { m.Lock() defer m.Unlock() - return m.writeBytes(ctx, conn, m.Bytes()) + return m.writeBytes(conn, m.Bytes()) } func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) { @@ -188,7 +189,7 @@ func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <- case logRange := <-ch: m.RLock() msg := m.Bytes()[logRange.Start:logRange.End] - err := m.writeBytes(ctx, conn, msg) + err := m.writeBytes(conn, msg) m.RUnlock() if err != nil { return diff --git a/internal/metrics/period/handler.go b/internal/metrics/period/handler.go index e2d4e1b..a93a977 100644 --- a/internal/metrics/period/handler.go +++ b/internal/metrics/period/handler.go @@ -5,8 +5,7 @@ import ( "net/http" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" + "github.com/gorilla/websocket" metricsutils "github.com/yusing/go-proxy/internal/metrics/utils" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" @@ -45,7 +44,7 @@ func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request if data == nil { return nil } - return wsjson.Write(r.Context(), conn, data) + return conn.WriteJSON(data) }) } else { data, err := p.getRespData(r) diff --git a/internal/net/gphttp/gpwebsocket/utils.go b/internal/net/gphttp/gpwebsocket/utils.go index e8f9135..0a0d229 100644 --- a/internal/net/gphttp/gpwebsocket/utils.go +++ b/internal/net/gphttp/gpwebsocket/utils.go @@ -1,11 +1,14 @@ package gpwebsocket import ( + "net" "net/http" + "slices" + "strings" "sync" "time" - "github.com/coder/websocket" + "github.com/gorilla/websocket" "github.com/yusing/go-proxy/internal/logging" ) @@ -27,29 +30,41 @@ func SetWebsocketAllowedDomains(h http.Header, domains []string) { h[HeaderXGoDoxyWebsocketAllowedDomains] = domains } -func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { - var originPats []string +var localAddresses = []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} +const writeTimeout = time.Second * 10 + +func Initiate(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { + upgrader := websocket.Upgrader{} allowedDomains := WebsocketAllowedDomains(r.Header) if len(allowedDomains) == 0 { warnNoMatchDomainOnce.Do(warnNoMatchDomains) - originPats = []string{"*"} - } else { - originPats = make([]string, len(allowedDomains)) - for i, domain := range allowedDomains { - if domain[0] != '.' { - originPats[i] = "*." + domain - } else { - originPats[i] = "*" + domain - } + upgrader.CheckOrigin = func(r *http.Request) bool { + return true + } + } else { + upgrader.CheckOrigin = func(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + if slices.Contains(localAddresses, host) { + return true + } + for _, domain := range allowedDomains { + if domain[0] == '.' { + if host == domain[1:] || strings.HasSuffix(host, domain) { + return true + } + } else if host == domain || strings.HasSuffix(host, "."+domain) { + return true + } + } + return false } - originPats = append(originPats, localAddresses...) } - return websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: originPats, - }) + return upgrader.Upgrade(w, r, nil) } func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { @@ -58,8 +73,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do http.Error(w, err.Error(), http.StatusInternalServerError) return } - //nolint:errcheck - defer conn.CloseNow() + defer conn.Close() if err := do(conn); err != nil { return @@ -73,6 +87,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do case <-r.Context().Done(): return case <-ticker.C: + _ = conn.SetWriteDeadline(time.Now().Add(writeTimeout)) if err := do(conn); err != nil { return } @@ -83,10 +98,7 @@ func Periodic(w http.ResponseWriter, r *http.Request, interval time.Duration, do // WriteText writes a text message to the websocket connection. // It returns true if the message was written successfully, false otherwise. // It logs an error if the message is not written successfully. -func WriteText(r *http.Request, conn *websocket.Conn, msg string) bool { - if err := conn.Write(r.Context(), websocket.MessageText, []byte(msg)); err != nil { - logging.Err(err).Msg("failed to write text message") - return false - } - return true +func WriteText(conn *websocket.Conn, msg string) error { + _ = conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + return conn.WriteMessage(websocket.TextMessage, []byte(msg)) } diff --git a/internal/net/gphttp/gpwebsocket/writer.go b/internal/net/gphttp/gpwebsocket/writer.go index 3e3998f..b47cf16 100644 --- a/internal/net/gphttp/gpwebsocket/writer.go +++ b/internal/net/gphttp/gpwebsocket/writer.go @@ -3,16 +3,16 @@ package gpwebsocket import ( "context" - "github.com/coder/websocket" + "github.com/gorilla/websocket" ) type Writer struct { conn *websocket.Conn - msgType websocket.MessageType + msgType int ctx context.Context } -func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) *Writer { +func NewWriter(ctx context.Context, conn *websocket.Conn, msgType int) *Writer { return &Writer{ ctx: ctx, conn: conn, @@ -21,9 +21,10 @@ func NewWriter(ctx context.Context, conn *websocket.Conn, msgType websocket.Mess } func (w *Writer) Write(p []byte) (int, error) { - return len(p), w.conn.Write(w.ctx, w.msgType, p) -} - -func (w *Writer) Close() error { - return w.conn.CloseNow() + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + return len(p), w.conn.WriteMessage(w.msgType, p) + } }