refactor header utils to httpheader package, cleanup api endpoints

This commit is contained in:
yusing 2025-02-13 07:32:59 +08:00
parent 5c9083a5df
commit 02d1c9ce98
19 changed files with 237 additions and 177 deletions

View file

@ -49,7 +49,7 @@ func NewAgentHandler() http.Handler {
fmt.Fprint(w, env.AgentName) fmt.Fprint(w, env.AgentName)
}) })
mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth) mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth)
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.LogsWS(nil)) mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc())
mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP) mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", DockerSocketHandler()) mux.ServeMux.HandleFunc("/", DockerSocketHandler())
return mux return mux

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"fmt"
"net/http" "net/http"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
@ -12,39 +13,73 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/uptime" "github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ServeMux struct{ *http.ServeMux } type (
ServeMux struct {
*http.ServeMux
cfg config.ConfigInstance
}
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
)
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
var handler http.HandlerFunc
switch h := h.(type) {
case func(http.ResponseWriter, *http.Request):
handler = h
case http.Handler:
handler = h.ServeHTTP
case WithCfgHandler:
handler = func(w http.ResponseWriter, r *http.Request) {
h(mux.cfg, w, r)
}
default:
panic(fmt.Errorf("unsupported handler type: %T", h))
}
matchDomains := mux.cfg.Value().MatchDomains
if len(matchDomains) > 0 {
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}
}
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
for _, m := range strutils.CommaSeperatedList(methods) { for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler) mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
} }
} }
func NewHandler(cfg config.ConfigInstance) http.Handler { func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux()} mux := ServeMux{http.NewServeMux(), cfg}
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/stats", v1.Stats, true)
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("POST", "/v1/reload", v1.Reload, true)
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list", v1.List, true)
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) mux.HandleFunc("GET", "/v1/list/{what}", v1.List, true)
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/list/{what}/{which}", v1.List, true)
mux.HandleFunc("POST", "/v1/file/validate/{type}", auth.RequireAuth(v1.ValidateFile)) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", v1.GetFileContent, true)
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", v1.SetFileContent, true)
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true)
mux.HandleFunc("GET", "/v1/health/ws", auth.RequireAuth(useCfg(cfg, v1.HealthWS))) mux.HandleFunc("GET", "/v1/health", v1.Health, true)
mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(memlogger.LogsWS(cfg.Value().MatchDomains))) mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true)
mux.HandleFunc("GET", "/v1/favicon", auth.RequireAuth(favicon.GetFavIcon)) mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true)
mux.HandleFunc("POST", "/v1/homepage/set", auth.RequireAuth(v1.SetHomePageOverrides)) mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true)
mux.HandleFunc("GET", "/v1/agents/ws", auth.RequireAuth(useCfg(cfg, v1.AgentsWS))) mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true)
mux.HandleFunc("GET", "/v1/metrics/system_info", auth.RequireAuth(useCfg(cfg, v1.SystemInfo))) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
mux.HandleFunc("GET", "/v1/metrics/system_info/ws", auth.RequireAuth(useCfg(cfg, v1.SystemInfo))) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)
mux.HandleFunc("GET", "/v1/metrics/uptime", auth.RequireAuth(uptime.Poller.ServeHTTP))
mux.HandleFunc("GET", "/v1/metrics/uptime/ws", auth.RequireAuth(useWS(cfg, uptime.Poller.ServeWS)))
if common.PrometheusEnabled { if common.PrometheusEnabled {
mux.Handle("GET /v1/metrics", promhttp.Handler()) mux.Handle("GET /v1/metrics", promhttp.Handler())
@ -69,15 +104,3 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
} }
return mux return mux
} }
func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(cfg, w, r)
}
}
func useWS(cfg config.ConfigInstance, handler func(allowedDomains []string, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(cfg.Value().MatchDomains, w, r)
}
}

View file

@ -8,11 +8,16 @@ import (
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 10*time.Second, func(conn *websocket.Conn) error { if httpheaders.IsWebsocket(r.Header) {
wsjson.Write(r.Context(), conn, cfg.ListAgents()) U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error {
return nil wsjson.Write(r.Context(), conn, cfg.ListAgents())
}) return nil
})
} else {
U.RespondJSON(w, r, cfg.ListAgents())
}
} }

View file

@ -7,12 +7,16 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/route/routes/routequery" "github.com/yusing/go-proxy/internal/route/routes/routequery"
) )
func HealthWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error { if httpheaders.IsWebsocket(r.Header) {
return wsjson.Write(r.Context(), conn, routequery.HealthMap()) U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
}) return wsjson.Write(r.Context(), conn, routequery.HealthMap())
})
} else {
U.RespondJSON(w, r, routequery.HealthMap())
}
} }

View file

@ -8,17 +8,18 @@ import (
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats(cfg)) if httpheaders.IsWebsocket(r.Header) {
} U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, getStats(cfg))
func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { })
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error { } else {
return wsjson.Write(r.Context(), conn, getStats(cfg)) U.RespondJSON(w, r, getStats(cfg))
}) }
} }
var startTime = time.Now() var startTime = time.Now()

View file

@ -2,24 +2,22 @@ package v1
import ( import (
"net/http" "net/http"
"strings"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent" agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/metrics/systeminfo" "github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
isWS := strings.HasSuffix(r.URL.Path, "/ws") query := r.URL.Query()
agentName := r.URL.Query().Get("agent_name") agentName := query.Get("agent_name")
query.Del("agent_name")
if agentName == "" { if agentName == "" {
if isWS { systeminfo.Poller.ServeHTTP(w, r)
systeminfo.Poller.ServeWS(cfg.Value().MatchDomains, w, r)
} else {
systeminfo.Poller.ServeHTTP(w, r)
}
return return
} }
@ -28,10 +26,12 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
U.HandleErr(w, r, U.ErrInvalidKey("agent_name"), http.StatusNotFound) U.HandleErr(w, r, U.ErrInvalidKey("agent_name"), http.StatusNotFound)
return return
} }
isWS := httpheaders.IsWebsocket(r.Header)
if !isWS { if !isWS {
respData, status, err := agent.Forward(r, agentPkg.EndpointSystemInfo) respData, status, err := agent.Forward(r, agentPkg.EndpointSystemInfo)
if err != nil { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, E.Wrap(err, "failed to forward request to agent"))
return return
} }
if status != http.StatusOK { if status != http.StatusOK {
@ -40,14 +40,16 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
} }
U.WriteBody(w, respData) U.WriteBody(w, respData)
} else { } else {
clientConn, err := U.InitiateWS(cfg.Value().MatchDomains, w, r) r = r.WithContext(r.Context())
clientConn, err := U.InitiateWS(w, r)
if err != nil { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, E.Wrap(err, "failed to initiate websocket"))
return return
} }
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo) defer clientConn.CloseNow()
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo+"?"+query.Encode())
if err != nil { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, E.Wrap(err, "failed to connect to agent with websocket"))
return return
} }
//nolint:errcheck //nolint:errcheck
@ -63,7 +65,7 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
err = wsjson.Write(r.Context(), clientConn, data) err = wsjson.Write(r.Context(), clientConn, data)
} }
if err != nil { if err != nil {
U.HandleErr(w, r, err) U.HandleErr(w, r, E.Wrap(err, "failed to write data to client"))
return return
} }
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
func warnNoMatchDomains() { func warnNoMatchDomains() {
@ -16,11 +17,12 @@ func warnNoMatchDomains() {
var warnNoMatchDomainOnce sync.Once var warnNoMatchDomainOnce sync.Once
func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { func InitiateWS(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug { if len(allowedDomains) == 0 || common.IsDebug {
warnNoMatchDomainOnce.Do(warnNoMatchDomains) warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"} originPats = []string{"*"}
@ -40,8 +42,8 @@ func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request)
}) })
} }
func PeriodicWS(allowedDomains []string, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { func PeriodicWS(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := InitiateWS(allowedDomains, w, r) conn, err := InitiateWS(w, r)
if err != nil { if err != nil {
HandleErr(w, r, err) HandleErr(w, r, err)
return return

View file

@ -80,16 +80,18 @@ func init() {
logging.InitLogger(zerolog.MultiLevelWriter(os.Stderr, memLoggerInstance)) logging.InitLogger(zerolog.MultiLevelWriter(os.Stderr, memLoggerInstance))
} }
func LogsWS(allowedDomains []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
memLoggerInstance.ServeHTTP(allowedDomains, w, r)
}
}
func GetMemLogger() MemLogger { func GetMemLogger() MemLogger {
return memLoggerInstance return memLoggerInstance
} }
func Handler() http.Handler {
return memLoggerInstance
}
func HandlerFunc() http.HandlerFunc {
return memLoggerInstance.ServeHTTP
}
func (m *memLogger) truncateIfNeeded(n int) { func (m *memLogger) truncateIfNeeded(n int) {
m.RLock() m.RLock()
needTruncate := m.Len()+n > maxMemLogSize needTruncate := m.Len()+n > maxMemLogSize
@ -150,8 +152,8 @@ func (m *memLogger) Write(p []byte) (n int, err error) {
return return
} }
func (m *memLogger) ServeHTTP(allowedDomains []string, w http.ResponseWriter, r *http.Request) { func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := utils.InitiateWS(allowedDomains, w, r) conn, err := utils.InitiateWS(w, r)
if err != nil { if err != nil {
utils.HandleErr(w, r, err) utils.HandleErr(w, r, err)
return return

View file

@ -1,6 +1,7 @@
package period package period
import ( import (
"errors"
"net/http" "net/http"
"time" "time"
@ -8,48 +9,22 @@ import (
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/api/v1/utils"
metricsutils "github.com/yusing/go-proxy/internal/metrics/utils" metricsutils "github.com/yusing/go-proxy/internal/metrics/utils"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
func (p *Poller[T, AggregateT]) lastResultHandler(w http.ResponseWriter, r *http.Request) { // ServeHTTP serves the data for the given period.
info := p.GetLastResult() //
if info == nil { // If the period is not specified, it serves the last result.
http.Error(w, "no system info", http.StatusNoContent) //
return // If the period is specified, it serves the data for the given period.
} //
utils.RespondJSON(w, r, info) // If the period is invalid, it returns a 400 error.
} //
// If the data is not found, it returns a 204 error.
//
// If the request is a websocket request, it serves the data for the given period for every interval.
func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query() query := r.URL.Query()
period := query.Get("period")
if period == "" {
p.lastResultHandler(w, r)
return
}
periodFilter := Filter(period)
if !periodFilter.IsValid() {
http.Error(w, "invalid period", http.StatusBadRequest)
return
}
rangeData := p.Get(periodFilter)
if len(rangeData) == 0 {
http.Error(w, "no data", http.StatusNoContent)
return
}
if p.aggregator != nil {
total, aggregated := p.aggregator(rangeData, query)
utils.RespondJSON(w, r, map[string]any{
"total": total,
"data": aggregated,
})
} else {
utils.RespondJSON(w, r, rangeData)
}
}
func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
period := query.Get("period")
interval := metricsutils.QueryDuration(query, "interval", 0) interval := metricsutils.QueryDuration(query, "interval", 0)
minInterval := 1 * time.Second minInterval := 1 * time.Second
@ -60,28 +35,52 @@ func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.Response
interval = minInterval interval = minInterval
} }
if period == "" { if httpheaders.IsWebsocket(r.Header) {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error { utils.PeriodicWS(w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, p.GetLastResult()) data, err := p.getRespData(r)
if err != nil {
return err
}
if data == nil {
return nil
}
return wsjson.Write(r.Context(), conn, data)
}) })
} else { } else {
periodFilter := Filter(period) data, err := p.getRespData(r)
if !periodFilter.IsValid() { if err != nil {
http.Error(w, "invalid period", http.StatusBadRequest) utils.HandleErr(w, r, err)
return return
} }
if p.aggregator != nil { if data == nil {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error { http.Error(w, "no data", http.StatusNoContent)
total, aggregated := p.aggregator(p.Get(periodFilter), query) return
return wsjson.Write(r.Context(), conn, map[string]any{
"total": total,
"data": aggregated,
})
})
} else {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, p.Get(periodFilter))
})
} }
utils.RespondJSON(w, r, data)
}
}
func (p *Poller[T, AggregateT]) getRespData(r *http.Request) (any, error) {
query := r.URL.Query()
period := query.Get("period")
if period == "" {
return p.GetLastResult(), nil
}
periodFilter := Filter(period)
if !periodFilter.IsValid() {
return nil, errors.New("invalid period")
}
rangeData := p.Get(periodFilter)
if len(rangeData) == 0 {
return nil, nil
}
if p.aggregator != nil {
total, aggregated := p.aggregator(rangeData, query)
return map[string]any{
"total": total,
"data": aggregated,
}, nil
} else {
return rangeData, nil
} }
} }

View file

@ -1,4 +1,4 @@
package http package httpheaders
import ( import (
"net/http" "net/http"

View file

@ -0,0 +1,21 @@
package httpheaders
import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View file

@ -10,6 +10,7 @@ import (
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
) )
@ -36,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Body.Close() resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage)) resp.ContentLength = int64(len(errorPage))
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else { } else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode) logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
} }
@ -61,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
ext := filepath.Ext(filename) ext := filepath.Ext(filename)
switch ext { switch ext {
case ".html": case ".html":
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js": case ".js":
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css": case ".css":
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default: default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename) logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
} }

View file

@ -4,7 +4,7 @@ import (
"net" "net"
"net/http" "net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
req.RemoteAddr = lastNonTrustedIP req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP) ri.AddTracef("set real ip %s", lastNonTrustedIP)
} }

View file

@ -6,14 +6,14 @@ import (
"strings" "strings"
"testing" "testing"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestSetRealIPOpts(t *testing.T) { func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{ opts := OptionsRaw{
"header": gphttp.HeaderXRealIP, "header": httpheaders.HeaderXRealIP,
"from": []string{ "from": []string{
"127.0.0.0/8", "127.0.0.0/8",
"192.168.0.0/16", "192.168.0.0/16",
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true, "recursive": true,
} }
optExpected := &RealIPOpts{ optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP, Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{ From: []*types.CIDR{
{ {
IP: net.ParseIP("127.0.0.0"), IP: net.ParseIP("127.0.0.0"),
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) { func TestSetRealIP(t *testing.T) {
const ( const (
testHeader = gphttp.HeaderXRealIP testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1" testRealIP = "192.168.1.1"
) )
opts := OptionsRaw{ opts := OptionsRaw{

View file

@ -3,7 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy" "github.com/yusing/go-proxy/internal/net/http/reverseproxy"
) )
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
// before implements RequestModifier. // before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamName, s.Name) r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme) r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host) r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port) r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true return true
} }

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"sync" "sync"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
type ( type (
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
return nil return nil
} }
tr.URL = req.RequestURI tr.URL = req.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(req.Header) tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr return tr
} }
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
return nil return nil
} }
tr.URL = resp.Request.RequestURI tr.URL = resp.Request.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header) tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = gphttp.HeaderToMap(resp.Header) tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode tr.RespStatus = resp.StatusCode
return tr return tr
} }

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
type ( type (
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return "" return ""
}, },
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) }, VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string { VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" { if upPort != "" {
return upHost + ":" + upPort return upHost + ":" + upPort
} }
return upHost return upHost
}, },
VarUpstreamURL: func(req *http.Request) string { VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" { if upScheme == "" {
return "" return ""
} }
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost upAddr := upHost
if upPort != "" { if upPort != "" {
upAddr += ":" + upPort upAddr += ":" + upPort

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"strings" "strings"
gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
type ( type (
@ -20,10 +20,10 @@ var (
// before implements RequestModifier. // before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor) r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr) clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil { if err == nil {
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP) r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
} }
return true return true
} }

View file

@ -26,8 +26,8 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/accesslog" "github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
@ -266,14 +266,14 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
p.rewriteRequestURL(outreq) p.rewriteRequestURL(outreq)
outreq.Close = false outreq.Close = false
reqUpType := gphttp.UpgradeType(outreq.Header) reqUpType := httpheaders.UpgradeType(outreq.Header)
if !IsPrint(reqUpType) { if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true) p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return return
} }
req.Header.Del("Forwarded") req.Header.Del("Forwarded")
gphttp.RemoveHopByHopHeaders(outreq.Header) httpheaders.RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support // Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to // that we support trailers. (We do, but we don't go out of our way to
@ -298,7 +298,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
// If we aren't the first proxy retain prior // If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space // X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one. // separated list and fold multiple headers into one.
prior, ok := outreq.Header[gphttp.HeaderXForwardedFor] prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
xff, _, err := net.SplitHostPort(req.RemoteAddr) xff, _, err := net.SplitHostPort(req.RemoteAddr)
@ -309,7 +309,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
xff = strings.Join(prior, ", ") + ", " + xff xff = strings.Join(prior, ", ") + ", " + xff
} }
if !omit { if !omit {
outreq.Header.Set(gphttp.HeaderXForwardedFor, xff) outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff)
} }
var reqScheme string var reqScheme string
@ -319,10 +319,10 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
reqScheme = "http" reqScheme = "http"
} }
outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method)
outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme) outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme)
outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host)
outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI) outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI)
if _, ok := outreq.Header["User-Agent"]; !ok { if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set, // If the outbound request doesn't have a User-Agent header set,
@ -389,7 +389,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return return
} }
gphttp.RemoveHopByHopHeaders(res.Header) httpheaders.RemoveHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, req, outreq) { if !p.modifyResponse(rw, res, req, outreq) {
return return
@ -460,8 +460,8 @@ func cleanWebsocketHeaders(req *http.Request) {
} }
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := gphttp.UpgradeType(req.Header) reqUpType := httpheaders.UpgradeType(req.Header)
resUpType := gphttp.UpgradeType(res.Header) resUpType := httpheaders.UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return return