From 2cbee10527b7b293377c9f32a4ebeb3c127c006b Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 5 Dec 2024 10:37:17 +0800 Subject: [PATCH] add $remote_host and $remote_port variables --- .../http/middleware/modify_request_test.go | 7 ++++++ .../http/middleware/modify_response_test.go | 11 +++++++++ internal/net/http/middleware/vars.go | 24 +++++++++++++++---- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 704f91b..422c53d 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "net" "net/http" "slices" "testing" @@ -26,6 +27,8 @@ func TestModifyRequest(t *testing.T) { "X-Test-Req-Uri": VarRequestURI, "X-Test-Req-Content-Type": VarRequestContentType, "X-Test-Req-Content-Length": VarRequestContentLen, + "X-Test-Remote-Host": VarRemoteHost, + "X-Test-Remote-Port": VarRemotePort, "X-Test-Remote-Addr": VarRemoteAddr, "X-Test-Upstream-Scheme": VarUpstreamScheme, "X-Test-Upstream-Host": VarUpstreamHost, @@ -76,6 +79,10 @@ func TestModifyRequest(t *testing.T) { ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json") ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100") + + remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Host"), remoteHost) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Port"), remotePort) ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 7da4c55..0e7ef15 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "net" "net/http" "slices" "testing" @@ -27,6 +28,10 @@ func TestModifyResponse(t *testing.T) { "X-Test-Req-Query": VarRequestQuery, "X-Test-Req-Url": VarRequestURL, "X-Test-Req-Uri": VarRequestURI, + "X-Test-Req-Content-Type": VarRequestContentType, + "X-Test-Req-Content-Length": VarRequestContentLen, + "X-Test-Remote-Host": VarRemoteHost, + "X-Test-Remote-Port": VarRemotePort, "X-Test-Remote-Addr": VarRemoteAddr, "X-Test-Upstream-Scheme": VarUpstreamScheme, "X-Test-Upstream-Host": VarUpstreamHost, @@ -83,6 +88,12 @@ func TestModifyResponse(t *testing.T) { ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String()) ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Type"), "application/json") + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Content-Length"), "100") + + remoteHost, remotePort, _ := net.SplitHostPort(result.RemoteAddr) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Host"), remoteHost) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Port"), remotePort) ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go index 1bf1bbf..106293f 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/http/middleware/vars.go @@ -34,6 +34,8 @@ const ( VarRequestURI = "$req_uri" VarRequestContentType = "$req_content_type" VarRequestContentLen = "$req_content_length" + VarRemoteHost = "$remote_host" + VarRemotePort = "$remote_port" VarRemoteAddr = "$remote_addr" VarUpstreamScheme = "$upstream_scheme" VarUpstreamHost = "$upstream_host" @@ -72,10 +74,24 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ VarRequestURI: func(req *Request) string { return req.URL.RequestURI() }, VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") }, VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, - VarRemoteAddr: func(req *Request) string { return req.RemoteAddr }, - VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, - VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, - VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + VarRemoteHost: func(req *Request) string { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return clientIP + } + return "" + }, + VarRemotePort: func(req *Request) string { + _, clientPort, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return clientPort + } + return "" + }, + VarRemoteAddr: func(req *Request) string { return req.RemoteAddr }, + VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, + VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, + VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, VarUpstreamAddr: func(req *Request) string { upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upPort := req.Header.Get(gphttp.HeaderUpstreamPort)