fix modifyResponse middleware incorrect variable substitution

This commit is contained in:
yusing 2024-12-05 10:31:48 +08:00
parent a9f6c4eb20
commit aff8a3b401
10 changed files with 255 additions and 104 deletions

View file

@ -137,13 +137,13 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
return
}
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *Response) error {
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
}
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}

View file

@ -16,14 +16,14 @@ type (
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
Request = http.Request
Response = http.Response
Response = gphttp.ProxyResponse
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc = gphttp.ModifyResponseFunc
ModifyResponseFunc func(*Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
@ -116,7 +116,9 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.modifyResponse != nil {
w = gphttp.NewModifyResponseWriter(w, r, m.modifyResponse)
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
})
}
if m.before != nil {
m.before(next, w, r)
@ -176,7 +178,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
if mid.before != nil {
ori := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
mid.before(ori, w, r)
}
}
@ -184,7 +186,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
if mid.modifyResponse != nil {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
rp.ModifyResponse = func(res *Response) error {
if err := mid.modifyResponse(res); err != nil {
return err
}

View file

@ -10,30 +10,30 @@ import (
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyRequest(t *testing.T) {
func TestModifyRequest(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{
"User-Agent": "go-proxy/v0.5.0",
"Host": "$upstream_addr",
"X-Test-Req-Method": "$req_method",
"X-Test-Req-Scheme": "$req_scheme",
"X-Test-Req-Host": "$req_host",
"X-Test-Req-Port": "$req_port",
"X-Test-Req-Addr": "$req_addr",
"X-Test-Req-Path": "$req_path",
"X-Test-Req-Query": "$req_query",
"X-Test-Req-Url": "$req_url",
"X-Test-Req-Uri": "$req_uri",
"X-Test-Req-Content-Type": "$req_content_type",
"X-Test-Req-Content-Length": "$req_content_length",
"X-Test-Remote-Addr": "$remote_addr",
"X-Test-Upstream-Scheme": "$upstream_scheme",
"X-Test-Upstream-Host": "$upstream_host",
"X-Test-Upstream-Port": "$upstream_port",
"X-Test-Upstream-Addr": "$upstream_addr",
"X-Test-Upstream-Url": "$upstream_url",
"X-Test-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
"User-Agent": "go-proxy/v0.5.0",
"Host": VarUpstreamAddr,
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Req-Content-Type": VarRequestContentType,
"X-Test-Req-Content-Length": VarRequestContentLen,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
@ -84,7 +84,7 @@ func TestSetModifyRequest(t *testing.T) {
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
})

View file

@ -1,6 +1,8 @@
package middleware
import (
"net/http"
E "github.com/yusing/go-proxy/internal/error"
)
@ -12,10 +14,13 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
next(w, r)
},
modifyResponse: func(resp *Response) error {
mr.m.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp)
mr.m.AddTraceResponse("before modify response", resp.Response)
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp.Response)
return nil
},
}

View file

@ -1,15 +1,41 @@
package middleware
import (
"bytes"
"net/http"
"slices"
"testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetModifyResponse(t *testing.T) {
func TestModifyResponse(t *testing.T) {
opts := OptionsRaw{
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
"set_headers": map[string]string{
"X-Test-Resp-Status": VarRespStatusCode,
"X-Test-Resp-Content-Type": VarRespContentType,
"X-Test-Resp-Content-Length": VarRespContentLen,
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
"X-Test-Req-Method": VarRequestMethod,
"X-Test-Req-Scheme": VarRequestScheme,
"X-Test-Req-Host": VarRequestHost,
"X-Test-Req-Port": VarRequestPort,
"X-Test-Req-Addr": VarRequestAddr,
"X-Test-Req-Path": VarRequestPath,
"X-Test-Req-Query": VarRequestQuery,
"X-Test-Req-Url": VarRequestURL,
"X-Test-Req-Uri": VarRequestURI,
"X-Test-Remote-Addr": VarRemoteAddr,
"X-Test-Upstream-Scheme": VarUpstreamScheme,
"X-Test-Upstream-Host": VarUpstreamHost,
"X-Test-Upstream-Port": VarUpstreamPort,
"X-Test-Upstream-Addr": VarUpstreamAddr,
"X-Test-Upstream-Url": VarUpstreamURL,
"X-Test-Header-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
},
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"},
}
@ -22,14 +48,50 @@ func TestSetModifyResponse(t *testing.T) {
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
})
t.Run("request_headers", func(t *testing.T) {
t.Run("response_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
respHeaders: http.Header{
"Content-Type": []string{"application/json"},
},
respBody: bytes.Repeat([]byte("a"), 50),
respStatus: http.StatusOK,
})
ExpectNoError(t, err)
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
})
}

View file

@ -34,24 +34,39 @@ func init() {
}
type requestRecorder struct {
args *testArgs
parent http.RoundTripper
headers http.Header
remoteAddr string
}
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
func newRequestRecorder(args *testArgs) *requestRecorder {
return &requestRecorder{args: args}
}
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
rt.headers = req.Header
rt.remoteAddr = req.RemoteAddr
if rt.parent != nil {
return rt.parent.RoundTrip(req)
resp, err = rt.parent.RoundTrip(req)
} else {
resp = &http.Response{
Status: http.StatusText(rt.args.respStatus),
StatusCode: rt.args.respStatus,
Header: testHeaders,
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
ContentLength: int64(len(rt.args.respBody)),
Request: req,
TLS: req.TLS,
}
}
return &http.Response{
StatusCode: http.StatusOK,
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
TLS: req.TLS,
}, nil
if err == nil {
for k, v := range rt.args.respHeaders {
resp.Header[k] = v
}
}
return resp, nil
}
type TestResult struct {
@ -64,56 +79,84 @@ type TestResult struct {
type testArgs struct {
middlewareOpt OptionsRaw
reqURL types.URL
upstreamURL types.URL
body []byte
realRoundTrip bool
headers http.Header
reqURL types.URL
reqMethod string
headers http.Header
body []byte
respHeaders http.Header
respBody []byte
respStatus int
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
var body io.Reader
var rr requestRecorder
var err error
if args == nil {
args = new(testArgs)
}
if args.body != nil {
body = bytes.NewReader(args.body)
}
func (args *testArgs) setDefaults() {
if args.reqURL.Nil() {
args.reqURL = E.Must(types.ParseURL("https://example.com"))
}
req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
for k, v := range args.headers {
req.Header[k] = v
if args.reqMethod == "" {
args.reqMethod = http.MethodGet
}
w := httptest.NewRecorder()
if args.upstreamURL.Nil() {
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
}
if args.respHeaders == nil {
args.respHeaders = http.Header{}
}
if args.respBody == nil {
args.respBody = []byte("OK")
}
if args.respStatus == 0 {
args.respStatus = http.StatusOK
}
}
func (args *testArgs) bodyReader() io.Reader {
if args.body != nil {
return bytes.NewReader(args.body)
}
return nil
}
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
for k, v := range args.headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
rr := newRequestRecorder(args)
if args.realRoundTrip {
rr.parent = http.DefaultTransport
}
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
patchReverseProxy(rp, []*Middleware{mid})
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rr.headers,
ResponseHeaders: resp.Header,

View file

@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"net/http"
"sync"
"time"
@ -42,7 +43,7 @@ func (tr *Trace) WithRequest(req *Request) *Trace {
return tr
}
func (tr *Trace) WithResponse(resp *Response) *Trace {
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
if tr == nil {
return nil
}
@ -103,7 +104,7 @@ func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
if !m.trace {
return nil
}

View file

@ -22,32 +22,61 @@ var (
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
const (
VarRequestMethod = "$req_method"
VarRequestScheme = "$req_scheme"
VarRequestHost = "$req_host"
VarRequestPort = "$req_port"
VarRequestPath = "$req_path"
VarRequestAddr = "$req_addr"
VarRequestQuery = "$req_query"
VarRequestURL = "$req_url"
VarRequestURI = "$req_uri"
VarRequestContentType = "$req_content_type"
VarRequestContentLen = "$req_content_length"
VarRemoteAddr = "$remote_addr"
VarUpstreamScheme = "$upstream_scheme"
VarUpstreamHost = "$upstream_host"
VarUpstreamPort = "$upstream_port"
VarUpstreamAddr = "$upstream_addr"
VarUpstreamURL = "$upstream_url"
VarRespContentType = "$resp_content_type"
VarRespContentLen = "$resp_content_length"
VarRespStatusCode = "$status_code"
)
var staticReqVarSubsMap = map[string]reqVarGetter{
"$req_method": func(req *Request) string { return req.Method },
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
"$req_host": func(req *Request) string {
VarRequestMethod: func(req *Request) string { return req.Method },
VarRequestScheme: func(req *Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
"$req_port": func(req *Request) string {
VarRequestPort: func(req *Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
"$req_addr": func(req *Request) string { return req.Host },
"$req_path": func(req *Request) string { return req.URL.Path },
"$req_query": func(req *Request) string { return req.URL.RawQuery },
"$req_url": func(req *Request) string { return req.URL.String() },
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
"$upstream_addr": func(req *Request) string {
VarRequestAddr: func(req *Request) string { return req.Host },
VarRequestPath: func(req *Request) string { return req.URL.Path },
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *Request) string { return req.URL.String() },
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
VarUpstreamAddr: func(req *Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" {
@ -55,7 +84,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
return upHost
},
"$upstream_url": func(req *Request) string {
VarUpstreamURL: func(req *Request) string {
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return ""
@ -71,9 +100,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
}
var staticRespVarSubsMap = map[string]respVarGetter{
"$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
"$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *Request, resp *Response, s string) string {
@ -99,7 +128,7 @@ func varReplace(req *Request, resp *Response, s string) string {
if resp != nil {
// Replace response headers
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
return resp.Header.Get(header)
})
}

View file

@ -0,0 +1,8 @@
package http
import "net/http"
type ProxyResponse struct {
*http.Response
OriginalRequest *http.Request
}

View file

@ -87,7 +87,7 @@ type ReverseProxy struct {
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
ModifyResponse func(*ProxyResponse) error
HandlerFunc http.HandlerFunc
@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() {
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
}
func rewriteRequestURL(req *http.Request, target types.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL)
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
targetQuery := p.TargetURL.RawQuery
req.URL.Scheme = p.TargetURL.Scheme
req.URL.Host = p.TargetURL.Host
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
@ -251,11 +251,11 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(res); err != nil {
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
res.Body.Close()
p.errorHandler(rw, req, err, true)
return false
@ -349,7 +349,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
rewriteRequestURL(outreq, p.TargetURL)
p.rewriteRequestURL(outreq)
outreq.Close = false
reqUpType := UpgradeType(outreq.Header)
@ -439,6 +439,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
roundTripMutex.Lock()
roundTripDone = true
roundTripMutex.Unlock()
@ -459,7 +460,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, outreq) {
if !p.modifyResponse(rw, res, req, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
@ -468,7 +469,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
RemoveHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, outreq) {
if !p.modifyResponse(rw, res, req, outreq) {
return
}