small update on reverse proxy and xforwarded middlewares

This commit is contained in:
yusing 2024-12-01 05:04:57 +08:00
parent a4f44348ef
commit 863bb3f474
7 changed files with 76 additions and 67 deletions

View file

@ -4,6 +4,16 @@ import (
"net/http" "net/http"
) )
const (
HeaderXForwardedMethod = "X-Forwarded-Method"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedHost = "X-Forwarded-Host"
HeaderXForwardedPort = "X-Forwarded-Port"
HeaderXForwardedURI = "X-Forwarded-Uri"
HeaderXRealIP = "X-Real-IP"
)
func RemoveHop(h http.Header) { func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h) reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h) RemoveHopByHopHeaders(h)

View file

@ -53,6 +53,13 @@ func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header) gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address url := fa.Address
faReq, err := http.NewRequestWithContext( faReq, err := http.NewRequestWithContext(
req.Context(), req.Context(),
@ -71,6 +78,8 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders) faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq) fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq) fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq) faResp, err := faHTTPClient.Do(faReq)
@ -100,7 +109,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
return return
} else if redirectURL.String() != "" { } else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String()) w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp) fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
} }
w.WriteHeader(faResp.StatusCode) w.WriteHeader(faResp.StatusCode)
@ -160,54 +169,54 @@ func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) { func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader { if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok { if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP clientIP = strings.Join(prior, ", ") + ", " + clientIP
} }
} }
faReq.Header.Set(xForwardedFor, clientIP) faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
} }
xMethod := req.Header.Get(xForwardedMethod) xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
switch { switch {
case xMethod != "" && fa.TrustForwardHeader: case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod) faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
case req.Method != "": case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method) faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
default: default:
faReq.Header.Del(xForwardedMethod) faReq.Header.Del(gphttp.HeaderXForwardedMethod)
} }
xfp := req.Header.Get(xForwardedProto) xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
switch { switch {
case xfp != "" && fa.TrustForwardHeader: case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp) faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
case req.TLS != nil: case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https") faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
default: default:
faReq.Header.Set(xForwardedProto, "http") faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
} }
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader { if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp) faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
} }
xfh := req.Header.Get(xForwardedHost) xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
switch { switch {
case xfh != "" && fa.TrustForwardHeader: case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh) faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
case req.Host != "": case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host) faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
default: default:
faReq.Header.Del(xForwardedHost) faReq.Header.Del(gphttp.HeaderXForwardedHost)
} }
xfURI := req.Header.Get(xForwardedURI) xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
switch { switch {
case xfURI != "" && fa.TrustForwardHeader: case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI) faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
case req.URL.RequestURI() != "": case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI()) faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
default: default:
faReq.Header.Del(xForwardedURI) faReq.Header.Del(gphttp.HeaderXForwardedURI)
} }
} }

View file

@ -4,6 +4,7 @@ import (
"net" "net"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
) )
@ -110,7 +111,7 @@ func (ri *realIP) setRealIP(req *Request) {
req.RemoteAddr = lastNonTrustedIP req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set("X-Real-IP", lastNonTrustedIP) req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
req.Header.Set(xForwardedFor, lastNonTrustedIP) req.Header.Set(gphttp.HeaderXForwardedFor, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP) ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
} }

View file

@ -6,13 +6,14 @@ import (
"strings" "strings"
"testing" "testing"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestSetRealIPOpts(t *testing.T) { func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{ opts := OptionsRaw{
"header": "X-Real-IP", "header": gphttp.HeaderXRealIP,
"from": []string{ "from": []string{
"127.0.0.0/8", "127.0.0.0/8",
"192.168.0.0/16", "192.168.0.0/16",
@ -21,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true, "recursive": true,
} }
optExpected := &realIPOpts{ optExpected := &realIPOpts{
Header: "X-Real-IP", Header: gphttp.HeaderXRealIP,
From: []*types.CIDR{ From: []*types.CIDR{
{ {
IP: net.ParseIP("127.0.0.0"), IP: net.ParseIP("127.0.0.0"),
@ -50,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) { func TestSetRealIP(t *testing.T) {
const ( const (
testHeader = "X-Real-IP" testHeader = gphttp.HeaderXRealIP
testRealIP = "192.168.1.1" testRealIP = "192.168.1.1"
) )
opts := OptionsRaw{ opts := OptionsRaw{
@ -73,5 +74,5 @@ func TestSetRealIP(t *testing.T) {
t.Log(traces) t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP) ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
} }

View file

@ -3,48 +3,26 @@ package middleware
import ( import (
"net" "net"
"strings" "strings"
)
const ( gphttp "github.com/yusing/go-proxy/internal/net/http"
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
) )
var SetXForwarded = &Middleware{ var SetXForwarded = &Middleware{
before: Rewrite(func(req *Request) { before: Rewrite(func(req *Request) {
delXForwarded(req) req.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(req.RemoteAddr) clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil { if err == nil {
req.Header.Set(xForwardedFor, clientIP) req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
} else {
req.Header.Set(xForwardedFor, req.RemoteAddr)
}
req.Header.Set(xForwardedHost, req.Host)
if req.TLS == nil {
req.Header.Set(xForwardedProto, "http")
} else {
req.Header.Set(xForwardedProto, "https")
} }
}), }),
} }
var HideXForwarded = &Middleware{ var HideXForwarded = &Middleware{
before: Rewrite(delXForwarded), before: Rewrite(func(req *Request) {
}
func delXForwarded(req *Request) {
req.Header.Del("Forwarded")
toRemove := make([]string, 0)
for k := range req.Header { for k := range req.Header {
if strings.HasPrefix(k, "X-Forwarded-") { if strings.HasPrefix(k, "X-Forwarded-") {
toRemove = append(toRemove, k)
}
}
for _, k := range toRemove {
req.Header.Del(k) req.Header.Del(k)
} }
} }
}),
}

View file

@ -34,6 +34,10 @@ func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyRes
} }
} }
func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter {
return w.w
}
func (w *ModifyResponseWriter) WriteHeader(code int) { func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent { if w.headerSent {
return return

View file

@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() {
metrics.GetRouteMetrics().UnregisterService(p.TargetName) metrics.GetRouteMetrics().UnregisterService(p.TargetName)
} }
func rewriteRequestURL(req *http.Request, target *url.URL) { func rewriteRequestURL(req *http.Request, target types.URL) {
targetQuery := target.RawQuery targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme req.URL.Scheme = target.Scheme
req.URL.Host = target.Host req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" { if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else { } else {
@ -346,7 +346,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
} }
rewriteRequestURL(outreq, p.TargetURL.URL) rewriteRequestURL(outreq, p.TargetURL)
outreq.Close = false outreq.Close = false
reqUpType := UpgradeType(outreq.Header) reqUpType := UpgradeType(outreq.Header)
@ -355,6 +355,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return return
} }
req.Header.Del("Forwarded")
RemoveHopByHopHeaders(outreq.Header) RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support // Issue 21096: tell backend applications that care about trailer support
@ -386,14 +387,19 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Set("X-Forwarded-For", clientIP) outreq.Header.Set("X-Forwarded-For", clientIP)
} }
} }
if req.TLS == nil {
outreq.Header.Set("X-Forwarded-Proto", "http") var reqScheme string
outreq.Header.Set("X-Forwarded-Scheme", "http") if req.TLS != nil {
reqScheme = "https"
} else { } else {
outreq.Header.Set("X-Forwarded-Proto", "https") reqScheme = "http"
outreq.Header.Set("X-Forwarded-Scheme", "https")
} }
outreq.Header.Set("X-Forwarded-Host", req.Host)
outreq.Header.Set(HeaderXForwardedMethod, req.Method)
outreq.Header.Set(HeaderXForwardedProto, reqScheme)
outreq.Header.Set(HeaderXForwardedHost, req.Host)
outreq.Header.Set(HeaderXForwardedURI, req.RequestURI)
outreq.Header.Set("Origin", reqScheme+"://"+req.Host)
if _, ok := outreq.Header["User-Agent"]; !ok { if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set, // If the outbound request doesn't have a User-Agent header set,