From aff8a3b401a20b2f39d62bb08e34129779e37b05 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 5 Dec 2024 10:31:48 +0800 Subject: [PATCH] fix modifyResponse middleware incorrect variable substitution --- internal/net/http/middleware/forward_auth.go | 4 +- internal/net/http/middleware/middleware.go | 12 +- .../http/middleware/modify_request_test.go | 46 ++++---- .../net/http/middleware/modify_response.go | 11 +- .../http/middleware/modify_response_test.go | 72 +++++++++++- internal/net/http/middleware/test_utils.go | 107 ++++++++++++------ internal/net/http/middleware/trace.go | 5 +- internal/net/http/middleware/vars.go | 71 ++++++++---- internal/net/http/proxy_response.go | 8 ++ internal/net/http/reverse_proxy_mod.go | 23 ++-- 10 files changed, 255 insertions(+), 104 deletions(-) create mode 100644 internal/net/http/proxy_response.go diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 28ed657..6a038d5 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -137,13 +137,13 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req return } - next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *Response) error { + next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error { fa.setAuthCookies(resp, authCookies) return nil }), req) } -func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) { +func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) { if len(fa.AddAuthCookiesToResponse) == 0 { return } diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index c6facf8..4a01dad 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -16,14 +16,14 @@ type ( ReverseProxy = gphttp.ReverseProxy ProxyRequest = gphttp.ProxyRequest Request = http.Request - Response = http.Response + Response = gphttp.ProxyResponse ResponseWriter = http.ResponseWriter Header = http.Header Cookie = http.Cookie BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request) RewriteFunc func(req *Request) - ModifyResponseFunc = gphttp.ModifyResponseFunc + ModifyResponseFunc func(*Response) error CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error) OptionsRaw = map[string]any @@ -116,7 +116,9 @@ func (m *Middleware) ModifyResponse(resp *Response) error { func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) { if m.modifyResponse != nil { - w = gphttp.NewModifyResponseWriter(w, r, m.modifyResponse) + w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error { + return m.modifyResponse(&Response{Response: resp, OriginalRequest: r}) + }) } if m.before != nil { m.before(next, w, r) @@ -176,7 +178,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { if mid.before != nil { ori := rp.HandlerFunc - rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + rp.HandlerFunc = func(w http.ResponseWriter, r *Request) { mid.before(ori, w, r) } } @@ -184,7 +186,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { if mid.modifyResponse != nil { if rp.ModifyResponse != nil { ori := rp.ModifyResponse - rp.ModifyResponse = func(res *http.Response) error { + rp.ModifyResponse = func(res *Response) error { if err := mid.modifyResponse(res); err != nil { return err } diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 83aade1..704f91b 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -10,30 +10,30 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestSetModifyRequest(t *testing.T) { +func TestModifyRequest(t *testing.T) { opts := OptionsRaw{ "set_headers": map[string]string{ - "User-Agent": "go-proxy/v0.5.0", - "Host": "$upstream_addr", - "X-Test-Req-Method": "$req_method", - "X-Test-Req-Scheme": "$req_scheme", - "X-Test-Req-Host": "$req_host", - "X-Test-Req-Port": "$req_port", - "X-Test-Req-Addr": "$req_addr", - "X-Test-Req-Path": "$req_path", - "X-Test-Req-Query": "$req_query", - "X-Test-Req-Url": "$req_url", - "X-Test-Req-Uri": "$req_uri", - "X-Test-Req-Content-Type": "$req_content_type", - "X-Test-Req-Content-Length": "$req_content_length", - "X-Test-Remote-Addr": "$remote_addr", - "X-Test-Upstream-Scheme": "$upstream_scheme", - "X-Test-Upstream-Host": "$upstream_host", - "X-Test-Upstream-Port": "$upstream_port", - "X-Test-Upstream-Addr": "$upstream_addr", - "X-Test-Upstream-Url": "$upstream_url", - "X-Test-Content-Type": "$header(Content-Type)", - "X-Test-Arg-Arg_1": "$arg(arg_1)", + "User-Agent": "go-proxy/v0.5.0", + "Host": VarUpstreamAddr, + "X-Test-Req-Method": VarRequestMethod, + "X-Test-Req-Scheme": VarRequestScheme, + "X-Test-Req-Host": VarRequestHost, + "X-Test-Req-Port": VarRequestPort, + "X-Test-Req-Addr": VarRequestAddr, + "X-Test-Req-Path": VarRequestPath, + "X-Test-Req-Query": VarRequestQuery, + "X-Test-Req-Url": VarRequestURL, + "X-Test-Req-Uri": VarRequestURI, + "X-Test-Req-Content-Type": VarRequestContentType, + "X-Test-Req-Content-Length": VarRequestContentLen, + "X-Test-Remote-Addr": VarRemoteAddr, + "X-Test-Upstream-Scheme": VarUpstreamScheme, + "X-Test-Upstream-Host": VarUpstreamHost, + "X-Test-Upstream-Port": VarUpstreamPort, + "X-Test-Upstream-Addr": VarUpstreamAddr, + "X-Test-Upstream-Url": VarUpstreamURL, + "X-Test-Header-Content-Type": "$header(Content-Type)", + "X-Test-Arg-Arg_1": "$arg(arg_1)", }, "add_headers": map[string]string{"Accept-Encoding": "test-value"}, "hide_headers": []string{"Accept"}, @@ -84,7 +84,7 @@ func TestSetModifyRequest(t *testing.T) { ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) - ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json") + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json") ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b") }) diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index 6ba7fe7..b5de559 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -1,6 +1,8 @@ package middleware import ( + "net/http" + E "github.com/yusing/go-proxy/internal/error" ) @@ -12,10 +14,13 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyResponse) mr.m = &Middleware{ impl: mr, + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { + next(w, r) + }, modifyResponse: func(resp *Response) error { - mr.m.AddTraceResponse("before modify response", resp) - mr.modifyHeaders(resp.Request, resp, resp.Header) - mr.m.AddTraceResponse("after modify response", resp) + mr.m.AddTraceResponse("before modify response", resp.Response) + mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header) + mr.m.AddTraceResponse("after modify response", resp.Response) return nil }, } diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 2672b5d..7da4c55 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -1,15 +1,41 @@ package middleware import ( + "bytes" + "net/http" "slices" "testing" + "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestSetModifyResponse(t *testing.T) { +func TestModifyResponse(t *testing.T) { opts := OptionsRaw{ - "set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"}, + "set_headers": map[string]string{ + "X-Test-Resp-Status": VarRespStatusCode, + "X-Test-Resp-Content-Type": VarRespContentType, + "X-Test-Resp-Content-Length": VarRespContentLen, + "X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)", + + "X-Test-Req-Method": VarRequestMethod, + "X-Test-Req-Scheme": VarRequestScheme, + "X-Test-Req-Host": VarRequestHost, + "X-Test-Req-Port": VarRequestPort, + "X-Test-Req-Addr": VarRequestAddr, + "X-Test-Req-Path": VarRequestPath, + "X-Test-Req-Query": VarRequestQuery, + "X-Test-Req-Url": VarRequestURL, + "X-Test-Req-Uri": VarRequestURI, + "X-Test-Remote-Addr": VarRemoteAddr, + "X-Test-Upstream-Scheme": VarUpstreamScheme, + "X-Test-Upstream-Host": VarUpstreamHost, + "X-Test-Upstream-Port": VarUpstreamPort, + "X-Test-Upstream-Addr": VarUpstreamAddr, + "X-Test-Upstream-Url": VarUpstreamURL, + "X-Test-Header-Content-Type": "$header(Content-Type)", + "X-Test-Arg-Arg_1": "$arg(arg_1)", + }, "add_headers": map[string]string{"Accept-Encoding": "test-value"}, "hide_headers": []string{"Accept"}, } @@ -22,14 +48,50 @@ func TestSetModifyResponse(t *testing.T) { ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string)) }) - t.Run("request_headers", func(t *testing.T) { + t.Run("response_headers", func(t *testing.T) { + reqURL := types.MustParseURL("https://my.app/?arg_1=b") + upstreamURL := types.MustParseURL("http://test.example.com") result, err := newMiddlewareTest(ModifyResponse, &testArgs{ middlewareOpt: opts, + reqURL: reqURL, + upstreamURL: upstreamURL, + body: bytes.Repeat([]byte("a"), 100), + headers: http.Header{ + "Content-Type": []string{"application/json"}, + }, + respHeaders: http.Header{ + "Content-Type": []string{"application/json"}, + }, + respBody: bytes.Repeat([]byte("a"), 50), + respStatus: http.StatusOK, }) ExpectNoError(t, err) - ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0") - t.Log(result.ResponseHeaders.Get("Accept-Encoding")) ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value")) ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "") + + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200") + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json") + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50") + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json") + + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) + + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) + + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json") + ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b") }) } diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index 858fe60..d7f292b 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -34,24 +34,39 @@ func init() { } type requestRecorder struct { + args *testArgs + parent http.RoundTripper headers http.Header remoteAddr string } -func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) { +func newRequestRecorder(args *testArgs) *requestRecorder { + return &requestRecorder{args: args} +} + +func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) { rt.headers = req.Header rt.remoteAddr = req.RemoteAddr if rt.parent != nil { - return rt.parent.RoundTrip(req) + resp, err = rt.parent.RoundTrip(req) + } else { + resp = &http.Response{ + Status: http.StatusText(rt.args.respStatus), + StatusCode: rt.args.respStatus, + Header: testHeaders, + Body: io.NopCloser(bytes.NewReader(rt.args.respBody)), + ContentLength: int64(len(rt.args.respBody)), + Request: req, + TLS: req.TLS, + } } - return &http.Response{ - StatusCode: http.StatusOK, - Header: testHeaders, - Body: io.NopCloser(bytes.NewBufferString("OK")), - Request: req, - TLS: req.TLS, - }, nil + if err == nil { + for k, v := range rt.args.respHeaders { + resp.Header[k] = v + } + } + return resp, nil } type TestResult struct { @@ -64,56 +79,84 @@ type TestResult struct { type testArgs struct { middlewareOpt OptionsRaw - reqURL types.URL upstreamURL types.URL - body []byte + realRoundTrip bool - headers http.Header + + reqURL types.URL + reqMethod string + headers http.Header + body []byte + + respHeaders http.Header + respBody []byte + respStatus int } -func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { - var body io.Reader - var rr requestRecorder - var err error - - if args == nil { - args = new(testArgs) - } - - if args.body != nil { - body = bytes.NewReader(args.body) - } - +func (args *testArgs) setDefaults() { if args.reqURL.Nil() { args.reqURL = E.Must(types.ParseURL("https://example.com")) } - - req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body) - for k, v := range args.headers { - req.Header[k] = v + if args.reqMethod == "" { + args.reqMethod = http.MethodGet } - w := httptest.NewRecorder() - if args.upstreamURL.Nil() { args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect } + if args.respHeaders == nil { + args.respHeaders = http.Header{} + } + if args.respBody == nil { + args.respBody = []byte("OK") + } + if args.respStatus == 0 { + args.respStatus = http.StatusOK + } +} +func (args *testArgs) bodyReader() io.Reader { + if args.body != nil { + return bytes.NewReader(args.body) + } + return nil +} + +func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { + if args == nil { + args = new(testArgs) + } + args.setDefaults() + + req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader()) + for k, v := range args.headers { + req.Header[k] = v + } + + w := httptest.NewRecorder() + + rr := newRequestRecorder(args) if args.realRoundTrip { rr.parent = http.DefaultTransport } - rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr) + + rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr) + mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr } + patchReverseProxy(rp, []*Middleware{mid}) rp.ServeHTTP(w, req) + resp := w.Result() defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) if err != nil { return nil, E.From(err) } + return &TestResult{ RequestHeaders: rr.headers, ResponseHeaders: resp.Header, diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index 1a8e8e8..5c46464 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/http" "sync" "time" @@ -42,7 +43,7 @@ func (tr *Trace) WithRequest(req *Request) *Trace { return tr } -func (tr *Trace) WithResponse(resp *Response) *Trace { +func (tr *Trace) WithResponse(resp *http.Response) *Trace { if tr == nil { return nil } @@ -103,7 +104,7 @@ func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace { return m.AddTracef("%s", msg).WithRequest(req) } -func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace { +func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace { if !m.trace { return nil } diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go index 68f5f67..1bf1bbf 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/http/middleware/vars.go @@ -22,32 +22,61 @@ var ( reStatic = regexp.MustCompile(`\$[\w_]+`) ) +const ( + VarRequestMethod = "$req_method" + VarRequestScheme = "$req_scheme" + VarRequestHost = "$req_host" + VarRequestPort = "$req_port" + VarRequestPath = "$req_path" + VarRequestAddr = "$req_addr" + VarRequestQuery = "$req_query" + VarRequestURL = "$req_url" + VarRequestURI = "$req_uri" + VarRequestContentType = "$req_content_type" + VarRequestContentLen = "$req_content_length" + VarRemoteAddr = "$remote_addr" + VarUpstreamScheme = "$upstream_scheme" + VarUpstreamHost = "$upstream_host" + VarUpstreamPort = "$upstream_port" + VarUpstreamAddr = "$upstream_addr" + VarUpstreamURL = "$upstream_url" + + VarRespContentType = "$resp_content_type" + VarRespContentLen = "$resp_content_length" + VarRespStatusCode = "$status_code" +) + var staticReqVarSubsMap = map[string]reqVarGetter{ - "$req_method": func(req *Request) string { return req.Method }, - "$req_scheme": func(req *Request) string { return req.URL.Scheme }, - "$req_host": func(req *Request) string { + VarRequestMethod: func(req *Request) string { return req.Method }, + VarRequestScheme: func(req *Request) string { + if req.TLS != nil { + return "https" + } + return "http" + }, + VarRequestHost: func(req *Request) string { reqHost, _, err := net.SplitHostPort(req.Host) if err != nil { return req.Host } return reqHost }, - "$req_port": func(req *Request) string { + VarRequestPort: func(req *Request) string { _, reqPort, _ := net.SplitHostPort(req.Host) return reqPort }, - "$req_addr": func(req *Request) string { return req.Host }, - "$req_path": func(req *Request) string { return req.URL.Path }, - "$req_query": func(req *Request) string { return req.URL.RawQuery }, - "$req_url": func(req *Request) string { return req.URL.String() }, - "$req_uri": func(req *Request) string { return req.URL.RequestURI() }, - "$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") }, - "$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, - "$remote_addr": func(req *Request) string { return req.RemoteAddr }, - "$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, - "$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, - "$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, - "$upstream_addr": func(req *Request) string { + VarRequestAddr: func(req *Request) string { return req.Host }, + VarRequestPath: func(req *Request) string { return req.URL.Path }, + VarRequestQuery: func(req *Request) string { return req.URL.RawQuery }, + VarRequestURL: func(req *Request) string { return req.URL.String() }, + VarRequestURI: func(req *Request) string { return req.URL.RequestURI() }, + VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") }, + VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, + VarRemoteAddr: func(req *Request) string { return req.RemoteAddr }, + VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, + VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, + VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + VarUpstreamAddr: func(req *Request) string { upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upPort := req.Header.Get(gphttp.HeaderUpstreamPort) if upPort != "" { @@ -55,7 +84,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ } return upHost }, - "$upstream_url": func(req *Request) string { + VarUpstreamURL: func(req *Request) string { upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) if upScheme == "" { return "" @@ -71,9 +100,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ } var staticRespVarSubsMap = map[string]respVarGetter{ - "$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") }, - "$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") }, - "$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) }, + VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") }, + VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) }, + VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) }, } func varReplace(req *Request, resp *Response, s string) string { @@ -99,7 +128,7 @@ func varReplace(req *Request, resp *Response, s string) string { if resp != nil { // Replace response headers s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string { - header := http.CanonicalHeaderKey(match[14 : len(match)-1]) + header := http.CanonicalHeaderKey(match[13 : len(match)-1]) return resp.Header.Get(header) }) } diff --git a/internal/net/http/proxy_response.go b/internal/net/http/proxy_response.go new file mode 100644 index 0000000..7a5c87c --- /dev/null +++ b/internal/net/http/proxy_response.go @@ -0,0 +1,8 @@ +package http + +import "net/http" + +type ProxyResponse struct { + *http.Response + OriginalRequest *http.Request +} diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index 15ed27a..136f9d3 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -87,7 +87,7 @@ type ReverseProxy struct { // If ModifyResponse returns an error, ErrorHandler is called // with its error value. If ErrorHandler is nil, its default // implementation is used. - ModifyResponse func(*http.Response) error + ModifyResponse func(*ProxyResponse) error HandlerFunc http.HandlerFunc @@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() { metrics.GetRouteMetrics().UnregisterService(p.TargetName) } -func rewriteRequestURL(req *http.Request, target types.URL) { - targetQuery := target.RawQuery - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host - req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL) +func (p *ReverseProxy) rewriteRequestURL(req *http.Request) { + targetQuery := p.TargetURL.RawQuery + req.URL.Scheme = p.TargetURL.Scheme + req.URL.Host = p.TargetURL.Host + req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL) if targetQuery == "" || req.URL.RawQuery == "" { req.URL.RawQuery = targetQuery + req.URL.RawQuery } else { @@ -251,11 +251,11 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err // modifyResponse conditionally runs the optional ModifyResponse hook // and reports whether the request should proceed. -func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool { if p.ModifyResponse == nil { return true } - if err := p.ModifyResponse(res); err != nil { + if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil { res.Body.Close() p.errorHandler(rw, req, err, true) return false @@ -349,7 +349,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate } - rewriteRequestURL(outreq, p.TargetURL) + p.rewriteRequestURL(outreq) outreq.Close = false reqUpType := UpgradeType(outreq.Header) @@ -439,6 +439,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) res, err := transport.RoundTrip(outreq) + roundTripMutex.Lock() roundTripDone = true roundTripMutex.Unlock() @@ -459,7 +460,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) if res.StatusCode == http.StatusSwitchingProtocols { - if !p.modifyResponse(rw, res, outreq) { + if !p.modifyResponse(rw, res, req, outreq) { return } p.handleUpgradeResponse(rw, outreq, res) @@ -468,7 +469,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { RemoveHopByHopHeaders(res.Header) - if !p.modifyResponse(rw, res, outreq) { + if !p.modifyResponse(rw, res, req, outreq) { return }