From fe7740f1b0165a3779be9a577e706c63539a07d1 Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 19 Jan 2025 04:33:55 +0800 Subject: [PATCH] api: cleanup websocket code --- internal/api/v1/stats.go | 43 ++--------------------- internal/api/v1/utils/ws.go | 68 +++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 41 deletions(-) create mode 100644 internal/api/v1/utils/ws.go diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index 0d9617b..e86c8de 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -1,14 +1,12 @@ package v1 import ( - "context" "net/http" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -18,46 +16,9 @@ func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { } func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - var originPats []string - - localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} - - if len(cfg.Value().MatchDomains) == 0 { - U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins") - originPats = []string{"*"} - } else { - originPats = make([]string, len(cfg.Value().MatchDomains)) - for i, domain := range cfg.Value().MatchDomains { - originPats[i] = "*" + domain - } - originPats = append(originPats, localAddresses...) - } - if common.IsDebug { - originPats = []string{"*"} - } - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: originPats, + U.PeriodicWS(cfg, w, r, 1*time.Second, func(conn *websocket.Conn) error { + return wsjson.Write(r.Context(), conn, getStats(cfg)) }) - if err != nil { - U.LogError(r).Err(err).Msg("failed to upgrade websocket") - return - } - /* trunk-ignore(golangci-lint/errcheck) */ - defer conn.CloseNow() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for range ticker.C { - stats := getStats(cfg) - if err := wsjson.Write(ctx, conn, stats); err != nil { - U.LogError(r).Msg("failed to write JSON") - return - } - } } var startTime = time.Now() diff --git a/internal/api/v1/utils/ws.go b/internal/api/v1/utils/ws.go new file mode 100644 index 0000000..28db66b --- /dev/null +++ b/internal/api/v1/utils/ws.go @@ -0,0 +1,68 @@ +package utils + +import ( + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/yusing/go-proxy/internal/common" + config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/logging" +) + +func warnNoMatchDomains() { + logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins") +} + +var warnNoMatchDomainOnce sync.Once + +func InitiateWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { + var originPats []string + + localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} + + if len(cfg.Value().MatchDomains) == 0 { + warnNoMatchDomainOnce.Do(warnNoMatchDomains) + originPats = []string{"*"} + } else { + originPats = make([]string, len(cfg.Value().MatchDomains)) + for i, domain := range cfg.Value().MatchDomains { + originPats[i] = "*" + domain + } + originPats = append(originPats, localAddresses...) + } + if common.IsDebug { + originPats = []string{"*"} + } + return websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: originPats, + }) +} + +func PeriodicWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { + conn, err := InitiateWS(cfg, w, r) + if err != nil { + HandleErr(w, r, err) + return + } + /* trunk-ignore(golangci-lint/errcheck) */ + defer conn.CloseNow() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-cfg.Context().Done(): + return + case <-r.Context().Done(): + return + case <-ticker.C: + if err := do(conn); err != nil { + HandleErr(w, r, err) + return + } + } + } +}