diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go index 7549240..cb2af9d 100644 --- a/agent/pkg/handler/handler.go +++ b/agent/pkg/handler/handler.go @@ -49,7 +49,7 @@ func NewAgentHandler() http.Handler { fmt.Fprint(w, env.AgentName) }) mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth) - mux.HandleMethods("GET", agent.EndpointLogs, memlogger.LogsWS(nil)) + mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc()) mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP) mux.ServeMux.HandleFunc("/", DockerSocketHandler()) return mux diff --git a/internal/api/handler.go b/internal/api/handler.go index 94909c2..1683d4a 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "net/http" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -12,39 +13,73 @@ import ( "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/metrics/uptime" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/utils/strutils" ) -type ServeMux struct{ *http.ServeMux } +type ( + ServeMux struct { + *http.ServeMux + cfg config.ConfigInstance + } + WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request) +) + +func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) { + var handler http.HandlerFunc + switch h := h.(type) { + case func(http.ResponseWriter, *http.Request): + handler = h + case http.Handler: + handler = h.ServeHTTP + case WithCfgHandler: + handler = func(w http.ResponseWriter, r *http.Request) { + h(mux.cfg, w, r) + } + default: + panic(fmt.Errorf("unsupported handler type: %T", h)) + } + + matchDomains := mux.cfg.Value().MatchDomains + if len(matchDomains) > 0 { + origHandler := handler + handler = func(w http.ResponseWriter, r *http.Request) { + if httpheaders.IsWebsocket(r.Header) { + httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains) + } + origHandler(w, r) + } + } + + if len(requireAuth) > 0 && requireAuth[0] { + handler = auth.RequireAuth(handler) + } -func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) { for _, m := range strutils.CommaSeperatedList(methods) { mux.ServeMux.HandleFunc(m+" "+endpoint, handler) } } func NewHandler(cfg config.ConfigInstance) http.Handler { - mux := ServeMux{http.NewServeMux()} + mux := ServeMux{http.NewServeMux(), cfg} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) - mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) - mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) - mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) - mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) - mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) - mux.HandleFunc("POST", "/v1/file/validate/{type}", auth.RequireAuth(v1.ValidateFile)) - mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) - mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) - mux.HandleFunc("GET", "/v1/health/ws", auth.RequireAuth(useCfg(cfg, v1.HealthWS))) - mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(memlogger.LogsWS(cfg.Value().MatchDomains))) - mux.HandleFunc("GET", "/v1/favicon", auth.RequireAuth(favicon.GetFavIcon)) - mux.HandleFunc("POST", "/v1/homepage/set", auth.RequireAuth(v1.SetHomePageOverrides)) - mux.HandleFunc("GET", "/v1/agents/ws", auth.RequireAuth(useCfg(cfg, v1.AgentsWS))) - mux.HandleFunc("GET", "/v1/metrics/system_info", auth.RequireAuth(useCfg(cfg, v1.SystemInfo))) - mux.HandleFunc("GET", "/v1/metrics/system_info/ws", auth.RequireAuth(useCfg(cfg, v1.SystemInfo))) - mux.HandleFunc("GET", "/v1/metrics/uptime", auth.RequireAuth(uptime.Poller.ServeHTTP)) - mux.HandleFunc("GET", "/v1/metrics/uptime/ws", auth.RequireAuth(useWS(cfg, uptime.Poller.ServeWS))) + + mux.HandleFunc("GET", "/v1/stats", v1.Stats, true) + mux.HandleFunc("POST", "/v1/reload", v1.Reload, true) + mux.HandleFunc("GET", "/v1/list", v1.List, true) + mux.HandleFunc("GET", "/v1/list/{what}", v1.List, true) + mux.HandleFunc("GET", "/v1/list/{what}/{which}", v1.List, true) + mux.HandleFunc("GET", "/v1/file/{type}/{filename}", v1.GetFileContent, true) + mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", v1.SetFileContent, true) + mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true) + mux.HandleFunc("GET", "/v1/health", v1.Health, true) + mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true) + mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true) + mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true) + mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true) + mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) + mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) if common.PrometheusEnabled { mux.Handle("GET /v1/metrics", promhttp.Handler()) @@ -69,15 +104,3 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { } return mux } - -func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - handler(cfg, w, r) - } -} - -func useWS(cfg config.ConfigInstance, handler func(allowedDomains []string, w http.ResponseWriter, r *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - handler(cfg.Value().MatchDomains, w, r) - } -} diff --git a/internal/api/v1/agents.go b/internal/api/v1/agents.go index a758591..4c39d12 100644 --- a/internal/api/v1/agents.go +++ b/internal/api/v1/agents.go @@ -8,11 +8,16 @@ import ( "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - U.PeriodicWS(cfg.Value().MatchDomains, w, r, 10*time.Second, func(conn *websocket.Conn) error { - wsjson.Write(r.Context(), conn, cfg.ListAgents()) - return nil - }) + if httpheaders.IsWebsocket(r.Header) { + U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error { + wsjson.Write(r.Context(), conn, cfg.ListAgents()) + return nil + }) + } else { + U.RespondJSON(w, r, cfg.ListAgents()) + } } diff --git a/internal/api/v1/health.go b/internal/api/v1/health.go index ddb47b2..2856aaf 100644 --- a/internal/api/v1/health.go +++ b/internal/api/v1/health.go @@ -7,12 +7,16 @@ import ( "github.com/coder/websocket" "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" - config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/route/routes/routequery" ) -func HealthWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, routequery.HealthMap()) - }) +func Health(w http.ResponseWriter, r *http.Request) { + if httpheaders.IsWebsocket(r.Header) { + U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error { + return wsjson.Write(r.Context(), conn, routequery.HealthMap()) + }) + } else { + U.RespondJSON(w, r, routequery.HealthMap()) + } } diff --git a/internal/api/v1/stats.go b/internal/api/v1/stats.go index d2be9c7..ffdece9 100644 --- a/internal/api/v1/stats.go +++ b/internal/api/v1/stats.go @@ -8,17 +8,18 @@ import ( "github.com/coder/websocket/wsjson" U "github.com/yusing/go-proxy/internal/api/v1/utils" config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/utils/strutils" ) func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - U.RespondJSON(w, r, getStats(cfg)) -} - -func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, getStats(cfg)) - }) + if httpheaders.IsWebsocket(r.Header) { + U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error { + return wsjson.Write(r.Context(), conn, getStats(cfg)) + }) + } else { + U.RespondJSON(w, r, getStats(cfg)) + } } var startTime = time.Now() diff --git a/internal/api/v1/system_info.go b/internal/api/v1/system_info.go index bc8e914..22d9bd4 100644 --- a/internal/api/v1/system_info.go +++ b/internal/api/v1/system_info.go @@ -2,24 +2,22 @@ package v1 import ( "net/http" - "strings" "github.com/coder/websocket/wsjson" agentPkg "github.com/yusing/go-proxy/agent/pkg/agent" U "github.com/yusing/go-proxy/internal/api/v1/utils" config "github.com/yusing/go-proxy/internal/config/types" + E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/metrics/systeminfo" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { - isWS := strings.HasSuffix(r.URL.Path, "/ws") - agentName := r.URL.Query().Get("agent_name") + query := r.URL.Query() + agentName := query.Get("agent_name") + query.Del("agent_name") if agentName == "" { - if isWS { - systeminfo.Poller.ServeWS(cfg.Value().MatchDomains, w, r) - } else { - systeminfo.Poller.ServeHTTP(w, r) - } + systeminfo.Poller.ServeHTTP(w, r) return } @@ -28,10 +26,12 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques U.HandleErr(w, r, U.ErrInvalidKey("agent_name"), http.StatusNotFound) return } + + isWS := httpheaders.IsWebsocket(r.Header) if !isWS { respData, status, err := agent.Forward(r, agentPkg.EndpointSystemInfo) if err != nil { - U.HandleErr(w, r, err) + U.HandleErr(w, r, E.Wrap(err, "failed to forward request to agent")) return } if status != http.StatusOK { @@ -40,14 +40,16 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques } U.WriteBody(w, respData) } else { - clientConn, err := U.InitiateWS(cfg.Value().MatchDomains, w, r) + r = r.WithContext(r.Context()) + clientConn, err := U.InitiateWS(w, r) if err != nil { - U.HandleErr(w, r, err) + U.HandleErr(w, r, E.Wrap(err, "failed to initiate websocket")) return } - agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo) + defer clientConn.CloseNow() + agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo+"?"+query.Encode()) if err != nil { - U.HandleErr(w, r, err) + U.HandleErr(w, r, E.Wrap(err, "failed to connect to agent with websocket")) return } //nolint:errcheck @@ -63,7 +65,7 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques err = wsjson.Write(r.Context(), clientConn, data) } if err != nil { - U.HandleErr(w, r, err) + U.HandleErr(w, r, E.Wrap(err, "failed to write data to client")) return } } diff --git a/internal/api/v1/utils/ws.go b/internal/api/v1/utils/ws.go index 2927198..aa2f985 100644 --- a/internal/api/v1/utils/ws.go +++ b/internal/api/v1/utils/ws.go @@ -8,6 +8,7 @@ import ( "github.com/coder/websocket" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) func warnNoMatchDomains() { @@ -16,11 +17,12 @@ func warnNoMatchDomains() { var warnNoMatchDomainOnce sync.Once -func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { +func InitiateWS(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { var originPats []string localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"} + allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header) if len(allowedDomains) == 0 || common.IsDebug { warnNoMatchDomainOnce.Do(warnNoMatchDomains) originPats = []string{"*"} @@ -40,8 +42,8 @@ func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) }) } -func PeriodicWS(allowedDomains []string, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { - conn, err := InitiateWS(allowedDomains, w, r) +func PeriodicWS(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) { + conn, err := InitiateWS(w, r) if err != nil { HandleErr(w, r, err) return diff --git a/internal/logging/memlogger/mem_logger.go b/internal/logging/memlogger/mem_logger.go index 39790bb..5904b84 100644 --- a/internal/logging/memlogger/mem_logger.go +++ b/internal/logging/memlogger/mem_logger.go @@ -80,16 +80,18 @@ func init() { logging.InitLogger(zerolog.MultiLevelWriter(os.Stderr, memLoggerInstance)) } -func LogsWS(allowedDomains []string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - memLoggerInstance.ServeHTTP(allowedDomains, w, r) - } -} - func GetMemLogger() MemLogger { return memLoggerInstance } +func Handler() http.Handler { + return memLoggerInstance +} + +func HandlerFunc() http.HandlerFunc { + return memLoggerInstance.ServeHTTP +} + func (m *memLogger) truncateIfNeeded(n int) { m.RLock() needTruncate := m.Len()+n > maxMemLogSize @@ -150,8 +152,8 @@ func (m *memLogger) Write(p []byte) (n int, err error) { return } -func (m *memLogger) ServeHTTP(allowedDomains []string, w http.ResponseWriter, r *http.Request) { - conn, err := utils.InitiateWS(allowedDomains, w, r) +func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := utils.InitiateWS(w, r) if err != nil { utils.HandleErr(w, r, err) return diff --git a/internal/metrics/period/handler.go b/internal/metrics/period/handler.go index 07b3852..d86f5c4 100644 --- a/internal/metrics/period/handler.go +++ b/internal/metrics/period/handler.go @@ -1,6 +1,7 @@ package period import ( + "errors" "net/http" "time" @@ -8,48 +9,22 @@ import ( "github.com/coder/websocket/wsjson" "github.com/yusing/go-proxy/internal/api/v1/utils" metricsutils "github.com/yusing/go-proxy/internal/metrics/utils" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) -func (p *Poller[T, AggregateT]) lastResultHandler(w http.ResponseWriter, r *http.Request) { - info := p.GetLastResult() - if info == nil { - http.Error(w, "no system info", http.StatusNoContent) - return - } - utils.RespondJSON(w, r, info) -} - +// ServeHTTP serves the data for the given period. +// +// If the period is not specified, it serves the last result. +// +// If the period is specified, it serves the data for the given period. +// +// If the period is invalid, it returns a 400 error. +// +// If the data is not found, it returns a 204 error. +// +// If the request is a websocket request, it serves the data for the given period for every interval. func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() - period := query.Get("period") - if period == "" { - p.lastResultHandler(w, r) - return - } - periodFilter := Filter(period) - if !periodFilter.IsValid() { - http.Error(w, "invalid period", http.StatusBadRequest) - return - } - rangeData := p.Get(periodFilter) - if len(rangeData) == 0 { - http.Error(w, "no data", http.StatusNoContent) - return - } - if p.aggregator != nil { - total, aggregated := p.aggregator(rangeData, query) - utils.RespondJSON(w, r, map[string]any{ - "total": total, - "data": aggregated, - }) - } else { - utils.RespondJSON(w, r, rangeData) - } -} - -func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - period := query.Get("period") interval := metricsutils.QueryDuration(query, "interval", 0) minInterval := 1 * time.Second @@ -60,28 +35,52 @@ func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.Response interval = minInterval } - if period == "" { - utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, p.GetLastResult()) + if httpheaders.IsWebsocket(r.Header) { + utils.PeriodicWS(w, r, interval, func(conn *websocket.Conn) error { + data, err := p.getRespData(r) + if err != nil { + return err + } + if data == nil { + return nil + } + return wsjson.Write(r.Context(), conn, data) }) } else { - periodFilter := Filter(period) - if !periodFilter.IsValid() { - http.Error(w, "invalid period", http.StatusBadRequest) + data, err := p.getRespData(r) + if err != nil { + utils.HandleErr(w, r, err) return } - if p.aggregator != nil { - utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error { - total, aggregated := p.aggregator(p.Get(periodFilter), query) - return wsjson.Write(r.Context(), conn, map[string]any{ - "total": total, - "data": aggregated, - }) - }) - } else { - utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error { - return wsjson.Write(r.Context(), conn, p.Get(periodFilter)) - }) + if data == nil { + http.Error(w, "no data", http.StatusNoContent) + return } + utils.RespondJSON(w, r, data) + } +} + +func (p *Poller[T, AggregateT]) getRespData(r *http.Request) (any, error) { + query := r.URL.Query() + period := query.Get("period") + if period == "" { + return p.GetLastResult(), nil + } + periodFilter := Filter(period) + if !periodFilter.IsValid() { + return nil, errors.New("invalid period") + } + rangeData := p.Get(periodFilter) + if len(rangeData) == 0 { + return nil, nil + } + if p.aggregator != nil { + total, aggregated := p.aggregator(rangeData, query) + return map[string]any{ + "total": total, + "data": aggregated, + }, nil + } else { + return rangeData, nil } } diff --git a/internal/net/http/header_utils.go b/internal/net/http/httpheaders/utils.go similarity index 99% rename from internal/net/http/header_utils.go rename to internal/net/http/httpheaders/utils.go index db8c78f..20e892b 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/http/httpheaders/utils.go @@ -1,4 +1,4 @@ -package http +package httpheaders import ( "net/http" diff --git a/internal/net/http/httpheaders/websocket.go b/internal/net/http/httpheaders/websocket.go new file mode 100644 index 0000000..755d324 --- /dev/null +++ b/internal/net/http/httpheaders/websocket.go @@ -0,0 +1,21 @@ +package httpheaders + +import ( + "net/http" +) + +const ( + HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains" +) + +func WebsocketAllowedDomains(h http.Header) []string { + return h[HeaderXGoDoxyWebsocketAllowedDomains] +} + +func SetWebsocketAllowedDomains(h http.Header, domains []string) { + h[HeaderXGoDoxyWebsocketAllowedDomains] = domains +} + +func IsWebsocket(h http.Header) bool { + return UpgradeType(h) == "websocket" +} diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index 538c418..2c50f02 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -10,6 +10,7 @@ import ( "github.com/yusing/go-proxy/internal/logging" gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/http/middleware/errorpage" ) @@ -36,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error { resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(errorPage)) resp.ContentLength = int64(len(errorPage)) - resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage))) - resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8") + resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage))) + resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8") } else { logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode) } @@ -61,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo ext := filepath.Ext(filename) switch ext { case ".html": - w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8") case ".js": - w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8") case ".css": - w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8") + w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8") default: logging.Error().Msgf("unexpected file type %q for %s", ext, filename) } diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 0b5a53d..f1925fb 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -4,7 +4,7 @@ import ( "net" "net/http" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/types" ) @@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) { req.RemoteAddr = lastNonTrustedIP req.Header.Set(ri.Header, lastNonTrustedIP) - req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) + req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP) ri.AddTracef("set real ip %s", lastNonTrustedIP) } diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index 02f5bd5..5b4d809 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -6,14 +6,14 @@ import ( "strings" "testing" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSetRealIPOpts(t *testing.T) { opts := OptionsRaw{ - "header": gphttp.HeaderXRealIP, + "header": httpheaders.HeaderXRealIP, "from": []string{ "127.0.0.0/8", "192.168.0.0/16", @@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) { "recursive": true, } optExpected := &RealIPOpts{ - Header: gphttp.HeaderXRealIP, + Header: httpheaders.HeaderXRealIP, From: []*types.CIDR{ { IP: net.ParseIP("127.0.0.0"), @@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) { func TestSetRealIP(t *testing.T) { const ( - testHeader = gphttp.HeaderXRealIP + testHeader = httpheaders.HeaderXRealIP testRealIP = "192.168.1.1" ) opts := OptionsRaw{ diff --git a/internal/net/http/middleware/set_upstream_headers.go b/internal/net/http/middleware/set_upstream_headers.go index 009fc84..488b1c5 100644 --- a/internal/net/http/middleware/set_upstream_headers.go +++ b/internal/net/http/middleware/set_upstream_headers.go @@ -3,7 +3,7 @@ package middleware import ( "net/http" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/http/reverseproxy" ) @@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware { // before implements RequestModifier. func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - r.Header.Set(gphttp.HeaderUpstreamName, s.Name) - r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme) - r.Header.Set(gphttp.HeaderUpstreamHost, s.Host) - r.Header.Set(gphttp.HeaderUpstreamPort, s.Port) + r.Header.Set(httpheaders.HeaderUpstreamName, s.Name) + r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme) + r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host) + r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port) return true } diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index c3b0c73..a4ebcc7 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -4,7 +4,7 @@ import ( "net/http" "sync" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) type ( @@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace { return nil } tr.URL = req.RequestURI - tr.ReqHeaders = gphttp.HeaderToMap(req.Header) + tr.ReqHeaders = httpheaders.HeaderToMap(req.Header) return tr } @@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace { return nil } tr.URL = resp.Request.RequestURI - tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header) - tr.RespHeaders = gphttp.HeaderToMap(resp.Header) + tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header) + tr.RespHeaders = httpheaders.HeaderToMap(resp.Header) tr.RespStatus = resp.StatusCode return tr } diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go index 0830542..20e716c 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/http/middleware/vars.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) type ( @@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ return "" }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, - VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) }, - VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, - VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, - VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) }, + VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) }, + VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) }, + VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) }, VarUpstreamAddr: func(req *http.Request) string { - upHost := req.Header.Get(gphttp.HeaderUpstreamHost) - upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) + upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) if upPort != "" { return upHost + ":" + upPort } return upHost }, VarUpstreamURL: func(req *http.Request) string { - upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) + upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme) if upScheme == "" { return "" } - upHost := req.Header.Get(gphttp.HeaderUpstreamHost) - upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) + upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) upAddr := upHost if upPort != "" { upAddr += ":" + upPort diff --git a/internal/net/http/middleware/x_forwarded.go b/internal/net/http/middleware/x_forwarded.go index ff8a558..7958a3d 100644 --- a/internal/net/http/middleware/x_forwarded.go +++ b/internal/net/http/middleware/x_forwarded.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - gphttp "github.com/yusing/go-proxy/internal/net/http" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" ) type ( @@ -20,10 +20,10 @@ var ( // before implements RequestModifier. func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - r.Header.Del(gphttp.HeaderXForwardedFor) + r.Header.Del(httpheaders.HeaderXForwardedFor) clientIP, _, err := net.SplitHostPort(r.RemoteAddr) if err == nil { - r.Header.Set(gphttp.HeaderXForwardedFor, clientIP) + r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP) } return true } diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/http/reverseproxy/reverse_proxy_mod.go index eb3985c..6d85eb0 100644 --- a/internal/net/http/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/http/reverseproxy/reverse_proxy_mod.go @@ -26,8 +26,8 @@ import ( "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/logging" - gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" "golang.org/x/net/http/httpguts" @@ -266,14 +266,14 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { p.rewriteRequestURL(outreq) outreq.Close = false - reqUpType := gphttp.UpgradeType(outreq.Header) + reqUpType := httpheaders.UpgradeType(outreq.Header) if !IsPrint(reqUpType) { p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true) return } req.Header.Del("Forwarded") - gphttp.RemoveHopByHopHeaders(outreq.Header) + httpheaders.RemoveHopByHopHeaders(outreq.Header) // Issue 21096: tell backend applications that care about trailer support // that we support trailers. (We do, but we don't go out of our way to @@ -298,7 +298,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - prior, ok := outreq.Header[gphttp.HeaderXForwardedFor] + prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor] omit := ok && prior == nil // Issue 38079: nil now means don't populate the header xff, _, err := net.SplitHostPort(req.RemoteAddr) @@ -309,7 +309,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { xff = strings.Join(prior, ", ") + ", " + xff } if !omit { - outreq.Header.Set(gphttp.HeaderXForwardedFor, xff) + outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff) } var reqScheme string @@ -319,10 +319,10 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { reqScheme = "http" } - outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) - outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme) - outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) - outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI) + outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method) + outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme) + outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host) + outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI) if _, ok := outreq.Header["User-Agent"]; !ok { // If the outbound request doesn't have a User-Agent header set, @@ -389,7 +389,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { return } - gphttp.RemoveHopByHopHeaders(res.Header) + httpheaders.RemoveHopByHopHeaders(res.Header) if !p.modifyResponse(rw, res, req, outreq) { return @@ -460,8 +460,8 @@ func cleanWebsocketHeaders(req *http.Request) { } func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := gphttp.UpgradeType(req.Header) - resUpType := gphttp.UpgradeType(res.Header) + reqUpType := httpheaders.UpgradeType(req.Header) + resUpType := httpheaders.UpgradeType(res.Header) if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) return