refactor header utils to httpheader package, cleanup api endpoints

This commit is contained in:
yusing 2025-02-13 07:32:59 +08:00
parent 5c9083a5df
commit 02d1c9ce98
19 changed files with 237 additions and 177 deletions

View file

@ -49,7 +49,7 @@ func NewAgentHandler() http.Handler {
fmt.Fprint(w, env.AgentName)
})
mux.HandleMethods("GET", agent.EndpointHealth, CheckHealth)
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.LogsWS(nil))
mux.HandleMethods("GET", agent.EndpointLogs, memlogger.HandlerFunc())
mux.HandleMethods("GET", agent.EndpointSystemInfo, systeminfo.Poller.ServeHTTP)
mux.ServeMux.HandleFunc("/", DockerSocketHandler())
return mux

View file

@ -1,6 +1,7 @@
package api
import (
"fmt"
"net/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
@ -12,39 +13,73 @@ import (
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/uptime"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
type ServeMux struct{ *http.ServeMux }
type (
ServeMux struct {
*http.ServeMux
cfg config.ConfigInstance
}
WithCfgHandler = func(config.ConfigInstance, http.ResponseWriter, *http.Request)
)
func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...bool) {
var handler http.HandlerFunc
switch h := h.(type) {
case func(http.ResponseWriter, *http.Request):
handler = h
case http.Handler:
handler = h.ServeHTTP
case WithCfgHandler:
handler = func(w http.ResponseWriter, r *http.Request) {
h(mux.cfg, w, r)
}
default:
panic(fmt.Errorf("unsupported handler type: %T", h))
}
matchDomains := mux.cfg.Value().MatchDomains
if len(matchDomains) > 0 {
origHandler := handler
handler = func(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
httpheaders.SetWebsocketAllowedDomains(r.Header, matchDomains)
}
origHandler(w, r)
}
}
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
}
func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux()}
mux := ServeMux{http.NewServeMux(), cfg}
mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("POST", "/v1/file/validate/{type}", auth.RequireAuth(v1.ValidateFile))
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
mux.HandleFunc("GET", "/v1/health/ws", auth.RequireAuth(useCfg(cfg, v1.HealthWS)))
mux.HandleFunc("GET", "/v1/logs/ws", auth.RequireAuth(memlogger.LogsWS(cfg.Value().MatchDomains)))
mux.HandleFunc("GET", "/v1/favicon", auth.RequireAuth(favicon.GetFavIcon))
mux.HandleFunc("POST", "/v1/homepage/set", auth.RequireAuth(v1.SetHomePageOverrides))
mux.HandleFunc("GET", "/v1/agents/ws", auth.RequireAuth(useCfg(cfg, v1.AgentsWS)))
mux.HandleFunc("GET", "/v1/metrics/system_info", auth.RequireAuth(useCfg(cfg, v1.SystemInfo)))
mux.HandleFunc("GET", "/v1/metrics/system_info/ws", auth.RequireAuth(useCfg(cfg, v1.SystemInfo)))
mux.HandleFunc("GET", "/v1/metrics/uptime", auth.RequireAuth(uptime.Poller.ServeHTTP))
mux.HandleFunc("GET", "/v1/metrics/uptime/ws", auth.RequireAuth(useWS(cfg, uptime.Poller.ServeWS)))
mux.HandleFunc("GET", "/v1/stats", v1.Stats, true)
mux.HandleFunc("POST", "/v1/reload", v1.Reload, true)
mux.HandleFunc("GET", "/v1/list", v1.List, true)
mux.HandleFunc("GET", "/v1/list/{what}", v1.List, true)
mux.HandleFunc("GET", "/v1/list/{what}/{which}", v1.List, true)
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", v1.GetFileContent, true)
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", v1.SetFileContent, true)
mux.HandleFunc("POST", "/v1/file/validate/{type}", v1.ValidateFile, true)
mux.HandleFunc("GET", "/v1/health", v1.Health, true)
mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true)
mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true)
mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true)
mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true)
mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)
if common.PrometheusEnabled {
mux.Handle("GET /v1/metrics", promhttp.Handler())
@ -69,15 +104,3 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
}
return mux
}
func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(cfg, w, r)
}
}
func useWS(cfg config.ConfigInstance, handler func(allowedDomains []string, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(cfg.Value().MatchDomains, w, r)
}
}

View file

@ -8,11 +8,16 @@ import (
"github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 10*time.Second, func(conn *websocket.Conn) error {
if httpheaders.IsWebsocket(r.Header) {
U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error {
wsjson.Write(r.Context(), conn, cfg.ListAgents())
return nil
})
} else {
U.RespondJSON(w, r, cfg.ListAgents())
}
}

View file

@ -7,12 +7,16 @@ import (
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/route/routes/routequery"
)
func HealthWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error {
func Health(w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) {
U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, routequery.HealthMap())
})
} else {
U.RespondJSON(w, r, routequery.HealthMap())
}
}

View file

@ -8,17 +8,18 @@ import (
"github.com/coder/websocket/wsjson"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func Stats(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, getStats(cfg))
}
func StatsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.PeriodicWS(cfg.Value().MatchDomains, w, r, 1*time.Second, func(conn *websocket.Conn) error {
if httpheaders.IsWebsocket(r.Header) {
U.PeriodicWS(w, r, 1*time.Second, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, getStats(cfg))
})
} else {
U.RespondJSON(w, r, getStats(cfg))
}
}
var startTime = time.Now()

View file

@ -2,24 +2,22 @@ package v1
import (
"net/http"
"strings"
"github.com/coder/websocket/wsjson"
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
isWS := strings.HasSuffix(r.URL.Path, "/ws")
agentName := r.URL.Query().Get("agent_name")
query := r.URL.Query()
agentName := query.Get("agent_name")
query.Del("agent_name")
if agentName == "" {
if isWS {
systeminfo.Poller.ServeWS(cfg.Value().MatchDomains, w, r)
} else {
systeminfo.Poller.ServeHTTP(w, r)
}
return
}
@ -28,10 +26,12 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
U.HandleErr(w, r, U.ErrInvalidKey("agent_name"), http.StatusNotFound)
return
}
isWS := httpheaders.IsWebsocket(r.Header)
if !isWS {
respData, status, err := agent.Forward(r, agentPkg.EndpointSystemInfo)
if err != nil {
U.HandleErr(w, r, err)
U.HandleErr(w, r, E.Wrap(err, "failed to forward request to agent"))
return
}
if status != http.StatusOK {
@ -40,14 +40,16 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
}
U.WriteBody(w, respData)
} else {
clientConn, err := U.InitiateWS(cfg.Value().MatchDomains, w, r)
r = r.WithContext(r.Context())
clientConn, err := U.InitiateWS(w, r)
if err != nil {
U.HandleErr(w, r, err)
U.HandleErr(w, r, E.Wrap(err, "failed to initiate websocket"))
return
}
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo)
defer clientConn.CloseNow()
agentConn, _, err := agent.Websocket(r.Context(), agentPkg.EndpointSystemInfo+"?"+query.Encode())
if err != nil {
U.HandleErr(w, r, err)
U.HandleErr(w, r, E.Wrap(err, "failed to connect to agent with websocket"))
return
}
//nolint:errcheck
@ -63,7 +65,7 @@ func SystemInfo(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Reques
err = wsjson.Write(r.Context(), clientConn, data)
}
if err != nil {
U.HandleErr(w, r, err)
U.HandleErr(w, r, E.Wrap(err, "failed to write data to client"))
return
}
}

View file

@ -8,6 +8,7 @@ import (
"github.com/coder/websocket"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
func warnNoMatchDomains() {
@ -16,11 +17,12 @@ func warnNoMatchDomains() {
var warnNoMatchDomainOnce sync.Once
func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
func InitiateWS(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
var originPats []string
localAddresses := []string{"127.0.0.1", "10.0.*.*", "172.16.*.*", "192.168.*.*"}
allowedDomains := httpheaders.WebsocketAllowedDomains(r.Header)
if len(allowedDomains) == 0 || common.IsDebug {
warnNoMatchDomainOnce.Do(warnNoMatchDomains)
originPats = []string{"*"}
@ -40,8 +42,8 @@ func InitiateWS(allowedDomains []string, w http.ResponseWriter, r *http.Request)
})
}
func PeriodicWS(allowedDomains []string, w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := InitiateWS(allowedDomains, w, r)
func PeriodicWS(w http.ResponseWriter, r *http.Request, interval time.Duration, do func(conn *websocket.Conn) error) {
conn, err := InitiateWS(w, r)
if err != nil {
HandleErr(w, r, err)
return

View file

@ -80,16 +80,18 @@ func init() {
logging.InitLogger(zerolog.MultiLevelWriter(os.Stderr, memLoggerInstance))
}
func LogsWS(allowedDomains []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
memLoggerInstance.ServeHTTP(allowedDomains, w, r)
}
}
func GetMemLogger() MemLogger {
return memLoggerInstance
}
func Handler() http.Handler {
return memLoggerInstance
}
func HandlerFunc() http.HandlerFunc {
return memLoggerInstance.ServeHTTP
}
func (m *memLogger) truncateIfNeeded(n int) {
m.RLock()
needTruncate := m.Len()+n > maxMemLogSize
@ -150,8 +152,8 @@ func (m *memLogger) Write(p []byte) (n int, err error) {
return
}
func (m *memLogger) ServeHTTP(allowedDomains []string, w http.ResponseWriter, r *http.Request) {
conn, err := utils.InitiateWS(allowedDomains, w, r)
func (m *memLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := utils.InitiateWS(w, r)
if err != nil {
utils.HandleErr(w, r, err)
return

View file

@ -1,6 +1,7 @@
package period
import (
"errors"
"net/http"
"time"
@ -8,48 +9,22 @@ import (
"github.com/coder/websocket/wsjson"
"github.com/yusing/go-proxy/internal/api/v1/utils"
metricsutils "github.com/yusing/go-proxy/internal/metrics/utils"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
func (p *Poller[T, AggregateT]) lastResultHandler(w http.ResponseWriter, r *http.Request) {
info := p.GetLastResult()
if info == nil {
http.Error(w, "no system info", http.StatusNoContent)
return
}
utils.RespondJSON(w, r, info)
}
// ServeHTTP serves the data for the given period.
//
// If the period is not specified, it serves the last result.
//
// If the period is specified, it serves the data for the given period.
//
// If the period is invalid, it returns a 400 error.
//
// If the data is not found, it returns a 204 error.
//
// If the request is a websocket request, it serves the data for the given period for every interval.
func (p *Poller[T, AggregateT]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
period := query.Get("period")
if period == "" {
p.lastResultHandler(w, r)
return
}
periodFilter := Filter(period)
if !periodFilter.IsValid() {
http.Error(w, "invalid period", http.StatusBadRequest)
return
}
rangeData := p.Get(periodFilter)
if len(rangeData) == 0 {
http.Error(w, "no data", http.StatusNoContent)
return
}
if p.aggregator != nil {
total, aggregated := p.aggregator(rangeData, query)
utils.RespondJSON(w, r, map[string]any{
"total": total,
"data": aggregated,
})
} else {
utils.RespondJSON(w, r, rangeData)
}
}
func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
period := query.Get("period")
interval := metricsutils.QueryDuration(query, "interval", 0)
minInterval := 1 * time.Second
@ -60,28 +35,52 @@ func (p *Poller[T, AggregateT]) ServeWS(allowedDomains []string, w http.Response
interval = minInterval
}
if period == "" {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, p.GetLastResult())
if httpheaders.IsWebsocket(r.Header) {
utils.PeriodicWS(w, r, interval, func(conn *websocket.Conn) error {
data, err := p.getRespData(r)
if err != nil {
return err
}
if data == nil {
return nil
}
return wsjson.Write(r.Context(), conn, data)
})
} else {
periodFilter := Filter(period)
if !periodFilter.IsValid() {
http.Error(w, "invalid period", http.StatusBadRequest)
data, err := p.getRespData(r)
if err != nil {
utils.HandleErr(w, r, err)
return
}
if p.aggregator != nil {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error {
total, aggregated := p.aggregator(p.Get(periodFilter), query)
return wsjson.Write(r.Context(), conn, map[string]any{
"total": total,
"data": aggregated,
})
})
} else {
utils.PeriodicWS(allowedDomains, w, r, interval, func(conn *websocket.Conn) error {
return wsjson.Write(r.Context(), conn, p.Get(periodFilter))
})
if data == nil {
http.Error(w, "no data", http.StatusNoContent)
return
}
utils.RespondJSON(w, r, data)
}
}
func (p *Poller[T, AggregateT]) getRespData(r *http.Request) (any, error) {
query := r.URL.Query()
period := query.Get("period")
if period == "" {
return p.GetLastResult(), nil
}
periodFilter := Filter(period)
if !periodFilter.IsValid() {
return nil, errors.New("invalid period")
}
rangeData := p.Get(periodFilter)
if len(rangeData) == 0 {
return nil, nil
}
if p.aggregator != nil {
total, aggregated := p.aggregator(rangeData, query)
return map[string]any{
"total": total,
"data": aggregated,
}, nil
} else {
return rangeData, nil
}
}

View file

@ -1,4 +1,4 @@
package http
package httpheaders
import (
"net/http"

View file

@ -0,0 +1,21 @@
package httpheaders
import (
"net/http"
)
const (
HeaderXGoDoxyWebsocketAllowedDomains = "X-GoDoxy-Websocket-Allowed-Domains"
)
func WebsocketAllowedDomains(h http.Header) []string {
return h[HeaderXGoDoxyWebsocketAllowedDomains]
}
func SetWebsocketAllowedDomains(h http.Header, domains []string) {
h[HeaderXGoDoxyWebsocketAllowedDomains] = domains
}
func IsWebsocket(h http.Header) bool {
return UpgradeType(h) == "websocket"
}

View file

@ -10,6 +10,7 @@ import (
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/http/middleware/errorpage"
)
@ -36,8 +37,8 @@ func (customErrorPage) modifyResponse(resp *http.Response) error {
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(errorPage))
resp.ContentLength = int64(len(errorPage))
resp.Header.Set(gphttp.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
resp.Header.Set(httpheaders.HeaderContentLength, strconv.Itoa(len(errorPage)))
resp.Header.Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
} else {
logging.Error().Msgf("unable to load error page for status %d", resp.StatusCode)
}
@ -61,11 +62,11 @@ func ServeStaticErrorPageFile(w http.ResponseWriter, r *http.Request) (served bo
ext := filepath.Ext(filename)
switch ext {
case ".html":
w.Header().Set(gphttp.HeaderContentType, "text/html; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "text/html; charset=utf-8")
case ".js":
w.Header().Set(gphttp.HeaderContentType, "application/javascript; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "application/javascript; charset=utf-8")
case ".css":
w.Header().Set(gphttp.HeaderContentType, "text/css; charset=utf-8")
w.Header().Set(httpheaders.HeaderContentType, "text/css; charset=utf-8")
default:
logging.Error().Msgf("unexpected file type %q for %s", ext, filename)
}

View file

@ -4,7 +4,7 @@ import (
"net"
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
)
@ -111,6 +111,6 @@ func (ri *realIP) setRealIP(req *http.Request) {
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
req.Header.Set(httpheaders.HeaderXRealIP, lastNonTrustedIP)
ri.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -6,14 +6,14 @@ import (
"strings"
"testing"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": gphttp.HeaderXRealIP,
"header": httpheaders.HeaderXRealIP,
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
@ -22,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true,
}
optExpected := &RealIPOpts{
Header: gphttp.HeaderXRealIP,
Header: httpheaders.HeaderXRealIP,
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
@ -51,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) {
const (
testHeader = gphttp.HeaderXRealIP
testHeader = httpheaders.HeaderXRealIP
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{

View file

@ -3,7 +3,7 @@ package middleware
import (
"net/http"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/http/reverseproxy"
)
@ -29,9 +29,9 @@ func newSetUpstreamHeaders(rp *reverseproxy.ReverseProxy) *Middleware {
// before implements RequestModifier.
func (s setUpstreamHeaders) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Set(gphttp.HeaderUpstreamName, s.Name)
r.Header.Set(gphttp.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, s.Host)
r.Header.Set(gphttp.HeaderUpstreamPort, s.Port)
r.Header.Set(httpheaders.HeaderUpstreamName, s.Name)
r.Header.Set(httpheaders.HeaderUpstreamScheme, s.Scheme)
r.Header.Set(httpheaders.HeaderUpstreamHost, s.Host)
r.Header.Set(httpheaders.HeaderUpstreamPort, s.Port)
return true
}

View file

@ -4,7 +4,7 @@ import (
"net/http"
"sync"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
type (
@ -37,7 +37,7 @@ func (tr *Trace) WithRequest(req *http.Request) *Trace {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(req.Header)
tr.ReqHeaders = httpheaders.HeaderToMap(req.Header)
return tr
}
@ -46,8 +46,8 @@ func (tr *Trace) WithResponse(resp *http.Response) *Trace {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = gphttp.HeaderToMap(resp.Request.Header)
tr.RespHeaders = gphttp.HeaderToMap(resp.Header)
tr.ReqHeaders = httpheaders.HeaderToMap(resp.Request.Header)
tr.RespHeaders = httpheaders.HeaderToMap(resp.Header)
tr.RespStatus = resp.StatusCode
return tr
}

View file

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
type (
@ -91,25 +91,25 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) },
VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) },
VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *http.Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
}
return upHost
},
VarUpstreamURL: func(req *http.Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme)
if upScheme == "" {
return ""
}
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upHost := req.Header.Get(httpheaders.HeaderUpstreamHost)
upPort := req.Header.Get(httpheaders.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort

View file

@ -5,7 +5,7 @@ import (
"net/http"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
)
type (
@ -20,10 +20,10 @@ var (
// before implements RequestModifier.
func (setXForwarded) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
r.Header.Del(gphttp.HeaderXForwardedFor)
r.Header.Del(httpheaders.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
r.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
r.Header.Set(httpheaders.HeaderXForwardedFor, clientIP)
}
return true
}

View file

@ -26,8 +26,8 @@ import (
"github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/logging"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/http/accesslog"
"github.com/yusing/go-proxy/internal/net/http/httpheaders"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
"golang.org/x/net/http/httpguts"
@ -266,14 +266,14 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
p.rewriteRequestURL(outreq)
outreq.Close = false
reqUpType := gphttp.UpgradeType(outreq.Header)
reqUpType := httpheaders.UpgradeType(outreq.Header)
if !IsPrint(reqUpType) {
p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true)
return
}
req.Header.Del("Forwarded")
gphttp.RemoveHopByHopHeaders(outreq.Header)
httpheaders.RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
@ -298,7 +298,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header[gphttp.HeaderXForwardedFor]
prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
xff, _, err := net.SplitHostPort(req.RemoteAddr)
@ -309,7 +309,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
xff = strings.Join(prior, ", ") + ", " + xff
}
if !omit {
outreq.Header.Set(gphttp.HeaderXForwardedFor, xff)
outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff)
}
var reqScheme string
@ -319,10 +319,10 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
reqScheme = "http"
}
outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme)
outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI)
outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method)
outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme)
outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host)
outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI)
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
@ -389,7 +389,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return
}
gphttp.RemoveHopByHopHeaders(res.Header)
httpheaders.RemoveHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, req, outreq) {
return
@ -460,8 +460,8 @@ func cleanWebsocketHeaders(req *http.Request) {
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := gphttp.UpgradeType(req.Header)
resUpType := gphttp.UpgradeType(res.Header)
reqUpType := httpheaders.UpgradeType(req.Header)
resUpType := httpheaders.UpgradeType(res.Header)
if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true)
return