mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-09 13:02:33 +02:00
api: cleanup websocket code
This commit is contained in:
parent
b253dce7e1
commit
fe7740f1b0
2 changed files with 70 additions and 41 deletions
|
@ -1,14 +1,12 @@
|
||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
"github.com/coder/websocket/wsjson"
|
"github.com/coder/websocket/wsjson"
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
@ -18,46 +16,9 @@ func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
|
||||||
var originPats []string
|
U.PeriodicWS(cfg, w, r, 1*time.Second, func(conn *websocket.Conn) error {
|
||||||
|
return wsjson.Write(r.Context(), conn, getStats(cfg))
|
||||||
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
|
||||||
|
|
||||||
if len(cfg.Value().MatchDomains) == 0 {
|
|
||||||
U.LogWarn(r).Msg("no match domains configured, accepting websocket API request from all origins")
|
|
||||||
originPats = []string{"*"}
|
|
||||||
} else {
|
|
||||||
originPats = make([]string, len(cfg.Value().MatchDomains))
|
|
||||||
for i, domain := range cfg.Value().MatchDomains {
|
|
||||||
originPats[i] = "*" + domain
|
|
||||||
}
|
|
||||||
originPats = append(originPats, localAddresses...)
|
|
||||||
}
|
|
||||||
if common.IsDebug {
|
|
||||||
originPats = []string{"*"}
|
|
||||||
}
|
|
||||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
||||||
OriginPatterns: originPats,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
U.LogError(r).Err(err).Msg("failed to upgrade websocket")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
/* trunk-ignore(golangci-lint/errcheck) */
|
|
||||||
defer conn.CloseNow()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
stats := getStats(cfg)
|
|
||||||
if err := wsjson.Write(ctx, conn, stats); err != nil {
|
|
||||||
U.LogError(r).Msg("failed to write JSON")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var startTime = time.Now()
|
var startTime = time.Now()
|
||||||
|
|
68
internal/api/v1/utils/ws.go
Normal file
68
internal/api/v1/utils/ws.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func warnNoMatchDomains() {
|
||||||
|
logging.Warn().Msg("no match domains configured, accepting websocket API request from all origins")
|
||||||
|
}
|
||||||
|
|
||||||
|
var warnNoMatchDomainOnce sync.Once
|
||||||
|
|
||||||
|
func InitiateWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
|
||||||
|
var originPats []string
|
||||||
|
|
||||||
|
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
|
||||||
|
|
||||||
|
if len(cfg.Value().MatchDomains) == 0 {
|
||||||
|
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
|
||||||
|
originPats = []string{"*"}
|
||||||
|
} else {
|
||||||
|
originPats = make([]string, len(cfg.Value().MatchDomains))
|
||||||
|
for i, domain := range cfg.Value().MatchDomains {
|
||||||
|
originPats[i] = "*" + domain
|
||||||
|
}
|
||||||
|
originPats = append(originPats, localAddresses...)
|
||||||
|
}
|
||||||
|
if common.IsDebug {
|
||||||
|
originPats = []string{"*"}
|
||||||
|
}
|
||||||
|
return websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: originPats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func PeriodicWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
|
||||||
|
conn, err := InitiateWS(cfg, w, r)
|
||||||
|
if err != nil {
|
||||||
|
HandleErr(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
/* trunk-ignore(golangci-lint/errcheck) */
|
||||||
|
defer conn.CloseNow()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cfg.Context().Done():
|
||||||
|
return
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := do(conn); err != nil {
|
||||||
|
HandleErr(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue