small update on reverse proxy and xforwarded middlewares

This commit is contained in:
yusing 2024-12-01 05:04:57 +08:00
parent a4f44348ef
commit 863bb3f474
7 changed files with 76 additions and 67 deletions

View file

@ -4,6 +4,16 @@ import (
"net/http"
)
const (
HeaderXForwardedMethod = "X-Forwarded-Method"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedHost = "X-Forwarded-Host"
HeaderXForwardedPort = "X-Forwarded-Port"
HeaderXForwardedURI = "X-Forwarded-Uri"
HeaderXRealIP = "X-Real-IP"
)
func RemoveHop(h http.Header) {
reqUpType := UpgradeType(h)
RemoveHopByHopHeaders(h)

View file

@ -53,6 +53,13 @@ func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
@ -71,6 +78,8 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
@ -100,7 +109,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp)
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
@ -160,54 +169,54 @@ func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
faReq.Header.Del(gphttp.HeaderXForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
faReq.Header.Del(gphttp.HeaderXForwardedURI)
}
}

View file

@ -4,6 +4,7 @@ import (
"net"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
)
@ -110,7 +111,7 @@ func (ri *realIP) setRealIP(req *Request) {
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set("X-Real-IP", lastNonTrustedIP)
req.Header.Set(xForwardedFor, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXRealIP, lastNonTrustedIP)
req.Header.Set(gphttp.HeaderXForwardedFor, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
}

View file

@ -6,13 +6,14 @@ import (
"strings"
"testing"
gphttp "github.com/yusing/go-proxy/internal/net/http"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"header": gphttp.HeaderXRealIP,
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
@ -21,7 +22,7 @@ func TestSetRealIPOpts(t *testing.T) {
"recursive": true,
}
optExpected := &realIPOpts{
Header: "X-Real-IP",
Header: gphttp.HeaderXRealIP,
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
@ -50,7 +51,7 @@ func TestSetRealIPOpts(t *testing.T) {
func TestSetRealIP(t *testing.T) {
const (
testHeader = "X-Real-IP"
testHeader = gphttp.HeaderXRealIP
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{
@ -73,5 +74,5 @@ func TestSetRealIP(t *testing.T) {
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(gphttp.HeaderXForwardedFor), testRealIP)
}

View file

@ -3,48 +3,26 @@ package middleware
import (
"net"
"strings"
)
const (
xForwardedFor = "X-Forwarded-For"
xForwardedMethod = "X-Forwarded-Method"
xForwardedHost = "X-Forwarded-Host"
xForwardedProto = "X-Forwarded-Proto"
xForwardedURI = "X-Forwarded-Uri"
xForwardedPort = "X-Forwarded-Port"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
var SetXForwarded = &Middleware{
before: Rewrite(func(req *Request) {
delXForwarded(req)
req.Header.Del(gphttp.HeaderXForwardedFor)
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
req.Header.Set(xForwardedFor, clientIP)
} else {
req.Header.Set(xForwardedFor, req.RemoteAddr)
}
req.Header.Set(xForwardedHost, req.Host)
if req.TLS == nil {
req.Header.Set(xForwardedProto, "http")
} else {
req.Header.Set(xForwardedProto, "https")
req.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
}),
}
var HideXForwarded = &Middleware{
before: Rewrite(delXForwarded),
}
func delXForwarded(req *Request) {
req.Header.Del("Forwarded")
toRemove := make([]string, 0)
for k := range req.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
toRemove = append(toRemove, k)
before: Rewrite(func(req *Request) {
for k := range req.Header {
if strings.HasPrefix(k, "X-Forwarded-") {
req.Header.Del(k)
}
}
}
for _, k := range toRemove {
req.Header.Del(k)
}
}),
}

View file

@ -34,6 +34,10 @@ func NewModifyResponseWriter(w http.ResponseWriter, r *http.Request, f ModifyRes
}
}
func (w *ModifyResponseWriter) Unwrap() http.ResponseWriter {
return w.w
}
func (w *ModifyResponseWriter) WriteHeader(code int) {
if w.headerSent {
return

View file

@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() {
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
func rewriteRequestURL(req *http.Request, target types.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
@ -346,7 +346,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
rewriteRequestURL(outreq, p.TargetURL.URL)
rewriteRequestURL(outreq, p.TargetURL)
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
@ -355,6 +355,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
return
}
req.Header.Del("Forwarded")
RemoveHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
@ -386,14 +387,19 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq.Header.Set("X-Forwarded-For", clientIP)
}
}
if req.TLS == nil {
outreq.Header.Set("X-Forwarded-Proto", "http")
outreq.Header.Set("X-Forwarded-Scheme", "http")
var reqScheme string
if req.TLS != nil {
reqScheme = "https"
} else {
outreq.Header.Set("X-Forwarded-Proto", "https")
outreq.Header.Set("X-Forwarded-Scheme", "https")
reqScheme = "http"
}
outreq.Header.Set("X-Forwarded-Host", req.Host)
outreq.Header.Set(HeaderXForwardedMethod, req.Method)
outreq.Header.Set(HeaderXForwardedProto, reqScheme)
outreq.Header.Set(HeaderXForwardedHost, req.Host)
outreq.Header.Set(HeaderXForwardedURI, req.RequestURI)
outreq.Header.Set("Origin", reqScheme+"://"+req.Host)
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,