diff --git a/internal/net/http/header_utils.go b/internal/net/http/header_utils.go index ffd0f22..f78677f 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/http/header_utils.go @@ -4,6 +4,16 @@ import ( "net/http" ) +const ( + HeaderXForwardedMethod = "X-Forwarded-Method" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedProto = "X-Forwarded-Proto" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXForwardedPort = "X-Forwarded-Port" + HeaderXForwardedURI = "X-Forwarded-Uri" + HeaderXRealIP = "X-Real-IP" +) + func RemoveHop(h http.Header) { reqUpType := UpgradeType(h) RemoveHopByHopHeaders(h) diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 7892c77..28ed657 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -53,6 +53,13 @@ func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) { func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { gphttp.RemoveHop(req.Header) + // Construct original URL for the redirect + // scheme := "http" + // if req.TLS != nil { + // scheme = "https" + // } + // originalURL := scheme + "://" + req.Host + req.RequestURI + url := fa.Address faReq, err := http.NewRequestWithContext( req.Context(), @@ -71,6 +78,8 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) fa.setAuthHeaders(req, faReq) + // Set headers needed by Authentik + // faReq.Header.Set("X-Original-URL", originalURL) fa.m.AddTraceRequest("forward auth request", faReq) faResp, err := faHTTPClient.Do(faReq) @@ -100,7 +109,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req return } else if redirectURL.String() != "" { w.Header().Set("Location", redirectURL.String()) - fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp) + fa.m.AddTracef("%s", "redirect to "+redirectURL.String()) } w.WriteHeader(faResp.StatusCode) @@ -160,54 +169,54 @@ func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) { func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) { if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if fa.TrustForwardHeader { - if prior, ok := req.Header[xForwardedFor]; ok { + if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } } - faReq.Header.Set(xForwardedFor, clientIP) + faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP) } - xMethod := req.Header.Get(xForwardedMethod) + xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod) switch { case xMethod != "" && fa.TrustForwardHeader: - faReq.Header.Set(xForwardedMethod, xMethod) + faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod) case req.Method != "": - faReq.Header.Set(xForwardedMethod, req.Method) + faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) default: - faReq.Header.Del(xForwardedMethod) + faReq.Header.Del(gphttp.HeaderXForwardedMethod) } - xfp := req.Header.Get(xForwardedProto) + xfp := req.Header.Get(gphttp.HeaderXForwardedProto) switch { case xfp != "" && fa.TrustForwardHeader: - faReq.Header.Set(xForwardedProto, xfp) + faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp) case req.TLS != nil: - faReq.Header.Set(xForwardedProto, "https") + faReq.Header.Set(gphttp.HeaderXForwardedProto, "https") default: - faReq.Header.Set(xForwardedProto, "http") + faReq.Header.Set(gphttp.HeaderXForwardedProto, "http") } - if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader { - faReq.Header.Set(xForwardedPort, xfp) + if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader { + faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp) } - xfh := req.Header.Get(xForwardedHost) + xfh := req.Header.Get(gphttp.HeaderXForwardedHost) switch { case xfh != "" && fa.TrustForwardHeader: - faReq.Header.Set(xForwardedHost, xfh) + faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh) case req.Host != "": - faReq.Header.Set(xForwardedHost, req.Host) + faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) default: - faReq.Header.Del(xForwardedHost) + faReq.Header.Del(gphttp.HeaderXForwardedHost) } - xfURI := req.Header.Get(xForwardedURI) + xfURI := req.Header.Get(gphttp.HeaderXForwardedURI) switch { case xfURI != "" && fa.TrustForwardHeader: - faReq.Header.Set(xForwardedURI, xfURI) + faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI) case req.URL.RequestURI() != "": - faReq.Header.Set(xForwardedURI, req.URL.RequestURI()) + faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI()) default: - faReq.Header.Del(xForwardedURI) + faReq.Header.Del(gphttp.HeaderXForwardedURI) } } diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index bdf0879..9d460b3 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -4,6 +4,7 @@ import ( "net" E "github.com/yusing/go-proxy/internal/error" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" ) @@ -110,7 +111,7 @@ func (ri *realIP) setRealIP(req *Request) { req.RemoteAddr = lastNonTrustedIP req.Header.Set(ri.Header, lastNonTrustedIP) - req.Header.Set("X-Real-IP", lastNonTrustedIP) - req.Header.Set(xForwardedFor, lastNonTrustedIP) + req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP) + req.Header.Set(gphttp.HeaderXForwardedFor, lastNonTrustedIP) ri.m.AddTracef("set real ip %s", lastNonTrustedIP) } diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index 67457b8..89a85e1 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -6,13 +6,14 @@ import ( "strings" "testing" + gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSetRealIPOpts(t *testing.T) { opts := OptionsRaw{ - "header": "X-Real-IP", + "header": gphttp.HeaderXRealIP, "from": []string{ "127.0.0.0/8", "192.168.0.0/16", @@ -21,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) { "recursive": true, } optExpected := &realIPOpts{ - Header: "X-Real-IP", + Header: gphttp.HeaderXRealIP, From: []*types.CIDR{ { IP: net.ParseIP("127.0.0.0"), @@ -50,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) { func TestSetRealIP(t *testing.T) { const ( - testHeader = "X-Real-IP" + testHeader = gphttp.HeaderXRealIP testRealIP = "192.168.1.1" ) opts := OptionsRaw{ @@ -73,5 +74,5 @@ func TestSetRealIP(t *testing.T) { t.Log(traces) ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) - ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP) + ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP) } diff --git a/internal/net/http/middleware/x_forwarded.go b/internal/net/http/middleware/x_forwarded.go index 2c73e26..9728214 100644 --- a/internal/net/http/middleware/x_forwarded.go +++ b/internal/net/http/middleware/x_forwarded.go @@ -3,48 +3,26 @@ package middleware import ( "net" "strings" -) -const ( - xForwardedFor = "X-Forwarded-For" - xForwardedMethod = "X-Forwarded-Method" - xForwardedHost = "X-Forwarded-Host" - xForwardedProto = "X-Forwarded-Proto" - xForwardedURI = "X-Forwarded-Uri" - xForwardedPort = "X-Forwarded-Port" + gphttp "github.com/yusing/go-proxy/internal/net/http" ) var SetXForwarded = &Middleware{ before: Rewrite(func(req *Request) { - delXForwarded(req) + req.Header.Del(gphttp.HeaderXForwardedFor) clientIP, _, err := net.SplitHostPort(req.RemoteAddr) if err == nil { - req.Header.Set(xForwardedFor, clientIP) - } else { - req.Header.Set(xForwardedFor, req.RemoteAddr) - } - req.Header.Set(xForwardedHost, req.Host) - if req.TLS == nil { - req.Header.Set(xForwardedProto, "http") - } else { - req.Header.Set(xForwardedProto, "https") + req.Header.Set(gphttp.HeaderXForwardedFor, clientIP) } }), } var HideXForwarded = &Middleware{ - before: Rewrite(delXForwarded), -} - -func delXForwarded(req *Request) { - req.Header.Del("Forwarded") - toRemove := make([]string, 0) - for k := range req.Header { - if strings.HasPrefix(k, "X-Forwarded-") { - toRemove = append(toRemove, k) + before: Rewrite(func(req *Request) { + for k := range req.Header { + if strings.HasPrefix(k, "X-Forwarded-") { + req.Header.Del(k) + } } - } - for _, k := range toRemove { - req.Header.Del(k) - } + }), } diff --git a/internal/net/http/modify_response_writer.go b/internal/net/http/modify_response_writer.go index 6a8122f..4da0d20 100644 --- a/internal/net/http/modify_response_writer.go +++ b/internal/net/http/modify_response_writer.go @@ -34,6 +34,10 @@ func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyRes } } +func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter { + return w.w +} + func (w *ModifyResponseWriter) WriteHeader(code int) { if w.headerSent { return diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 6a4f18f..d06f42b 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() { metrics.GetRouteMetrics().UnregisterService(p.TargetName) } -func rewriteRequestURL(req *http.Request, target *url.URL) { +func rewriteRequestURL(req *http.Request, target types.URL) { targetQuery := target.RawQuery req.URL.Scheme = target.Scheme req.URL.Host = target.Host - req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL) if targetQuery == "" || req.URL.RawQuery == "" { req.URL.RawQuery = targetQuery + req.URL.RawQuery } else { @@ -346,7 +346,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate } - rewriteRequestURL(outreq, p.TargetURL.URL) + rewriteRequestURL(outreq, p.TargetURL) outreq.Close = false reqUpType := UpgradeType(outreq.Header) @@ -355,6 +355,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { return } + req.Header.Del("Forwarded") RemoveHopByHopHeaders(outreq.Header) // Issue 21096: tell backend applications that care about trailer support @@ -386,14 +387,19 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { outreq.Header.Set("X-Forwarded-For", clientIP) } } - if req.TLS == nil { - outreq.Header.Set("X-Forwarded-Proto", "http") - outreq.Header.Set("X-Forwarded-Scheme", "http") + + var reqScheme string + if req.TLS != nil { + reqScheme = "https" } else { - outreq.Header.Set("X-Forwarded-Proto", "https") - outreq.Header.Set("X-Forwarded-Scheme", "https") + reqScheme = "http" } - outreq.Header.Set("X-Forwarded-Host", req.Host) + + outreq.Header.Set(HeaderXForwardedMethod, req.Method) + outreq.Header.Set(HeaderXForwardedProto, reqScheme) + outreq.Header.Set(HeaderXForwardedHost, req.Host) + outreq.Header.Set(HeaderXForwardedURI, req.RequestURI) + outreq.Header.Set("Origin", reqScheme+"://"+req.Host) if _, ok := outreq.Header["User-Agent"]; !ok { // If the outbound request doesn't have a User-Agent header set,