mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix modifyResponse middleware incorrect variable substitution
This commit is contained in:
parent
a9f6c4eb20
commit
aff8a3b401
10 changed files with 255 additions and 104 deletions
|
@ -137,13 +137,13 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
|||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *Response) error {
|
||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
|
||||
fa.setAuthCookies(resp, authCookies)
|
||||
return nil
|
||||
}), req)
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
|
||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -16,14 +16,14 @@ type (
|
|||
ReverseProxy = gphttp.ReverseProxy
|
||||
ProxyRequest = gphttp.ProxyRequest
|
||||
Request = http.Request
|
||||
Response = http.Response
|
||||
Response = gphttp.ProxyResponse
|
||||
ResponseWriter = http.ResponseWriter
|
||||
Header = http.Header
|
||||
Cookie = http.Cookie
|
||||
|
||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc = gphttp.ModifyResponseFunc
|
||||
ModifyResponseFunc func(*Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||
|
||||
OptionsRaw = map[string]any
|
||||
|
@ -116,7 +116,9 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
|
|||
|
||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if m.modifyResponse != nil {
|
||||
w = gphttp.NewModifyResponseWriter(w, r, m.modifyResponse)
|
||||
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
|
||||
})
|
||||
}
|
||||
if m.before != nil {
|
||||
m.before(next, w, r)
|
||||
|
@ -176,7 +178,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
|||
|
||||
if mid.before != nil {
|
||||
ori := rp.HandlerFunc
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
||||
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
|
||||
mid.before(ori, w, r)
|
||||
}
|
||||
}
|
||||
|
@ -184,7 +186,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
|||
if mid.modifyResponse != nil {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
rp.ModifyResponse = func(res *Response) error {
|
||||
if err := mid.modifyResponse(res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -10,30 +10,30 @@ import (
|
|||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetModifyRequest(t *testing.T) {
|
||||
func TestModifyRequest(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": "$upstream_addr",
|
||||
"X-Test-Req-Method": "$req_method",
|
||||
"X-Test-Req-Scheme": "$req_scheme",
|
||||
"X-Test-Req-Host": "$req_host",
|
||||
"X-Test-Req-Port": "$req_port",
|
||||
"X-Test-Req-Addr": "$req_addr",
|
||||
"X-Test-Req-Path": "$req_path",
|
||||
"X-Test-Req-Query": "$req_query",
|
||||
"X-Test-Req-Url": "$req_url",
|
||||
"X-Test-Req-Uri": "$req_uri",
|
||||
"X-Test-Req-Content-Type": "$req_content_type",
|
||||
"X-Test-Req-Content-Length": "$req_content_length",
|
||||
"X-Test-Remote-Addr": "$remote_addr",
|
||||
"X-Test-Upstream-Scheme": "$upstream_scheme",
|
||||
"X-Test-Upstream-Host": "$upstream_host",
|
||||
"X-Test-Upstream-Port": "$upstream_port",
|
||||
"X-Test-Upstream-Addr": "$upstream_addr",
|
||||
"X-Test-Upstream-Url": "$upstream_url",
|
||||
"X-Test-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": VarUpstreamAddr,
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
|
@ -84,7 +84,7 @@ func TestSetModifyRequest(t *testing.T) {
|
|||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
|
@ -12,10 +14,13 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
|||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
next(w, r)
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp)
|
||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||
mr.m.AddTraceResponse("after modify response", resp)
|
||||
mr.m.AddTraceResponse("before modify response", resp.Response)
|
||||
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
|
||||
mr.m.AddTraceResponse("after modify response", resp.Response)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,15 +1,41 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetModifyResponse(t *testing.T) {
|
||||
func TestModifyResponse(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
||||
"set_headers": map[string]string{
|
||||
"X-Test-Resp-Status": VarRespStatusCode,
|
||||
"X-Test-Resp-Content-Type": VarRespContentType,
|
||||
"X-Test-Resp-Content-Length": VarRespContentLen,
|
||||
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
|
||||
|
||||
"X-Test-Req-Method": VarRequestMethod,
|
||||
"X-Test-Req-Scheme": VarRequestScheme,
|
||||
"X-Test-Req-Host": VarRequestHost,
|
||||
"X-Test-Req-Port": VarRequestPort,
|
||||
"X-Test-Req-Addr": VarRequestAddr,
|
||||
"X-Test-Req-Path": VarRequestPath,
|
||||
"X-Test-Req-Query": VarRequestQuery,
|
||||
"X-Test-Req-Url": VarRequestURL,
|
||||
"X-Test-Req-Uri": VarRequestURI,
|
||||
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
}
|
||||
|
@ -22,14 +48,50 @@ func TestSetModifyResponse(t *testing.T) {
|
|||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
t.Run("response_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respBody: bytes.Repeat([]byte("a"), 50),
|
||||
respStatus: http.StatusOK,
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
|
||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -34,24 +34,39 @@ func init() {
|
|||
}
|
||||
|
||||
type requestRecorder struct {
|
||||
args *testArgs
|
||||
|
||||
parent http.RoundTripper
|
||||
headers http.Header
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func newRequestRecorder(args *testArgs) *requestRecorder {
|
||||
return &requestRecorder{args: args}
|
||||
}
|
||||
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
rt.headers = req.Header
|
||||
rt.remoteAddr = req.RemoteAddr
|
||||
if rt.parent != nil {
|
||||
return rt.parent.RoundTrip(req)
|
||||
resp, err = rt.parent.RoundTrip(req)
|
||||
} else {
|
||||
resp = &http.Response{
|
||||
Status: http.StatusText(rt.args.respStatus),
|
||||
StatusCode: rt.args.respStatus,
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
|
||||
ContentLength: int64(len(rt.args.respBody)),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}, nil
|
||||
if err == nil {
|
||||
for k, v := range rt.args.respHeaders {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
|
@ -64,56 +79,84 @@ type TestResult struct {
|
|||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
reqURL types.URL
|
||||
upstreamURL types.URL
|
||||
body []byte
|
||||
|
||||
realRoundTrip bool
|
||||
headers http.Header
|
||||
|
||||
reqURL types.URL
|
||||
reqMethod string
|
||||
headers http.Header
|
||||
body []byte
|
||||
|
||||
respHeaders http.Header
|
||||
respBody []byte
|
||||
respStatus int
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
var body io.Reader
|
||||
var rr requestRecorder
|
||||
var err error
|
||||
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
|
||||
if args.body != nil {
|
||||
body = bytes.NewReader(args.body)
|
||||
}
|
||||
|
||||
func (args *testArgs) setDefaults() {
|
||||
if args.reqURL.Nil() {
|
||||
args.reqURL = E.Must(types.ParseURL("https://example.com"))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
|
||||
for k, v := range args.headers {
|
||||
req.Header[k] = v
|
||||
if args.reqMethod == "" {
|
||||
args.reqMethod = http.MethodGet
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if args.upstreamURL.Nil() {
|
||||
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
}
|
||||
if args.respHeaders == nil {
|
||||
args.respHeaders = http.Header{}
|
||||
}
|
||||
if args.respBody == nil {
|
||||
args.respBody = []byte("OK")
|
||||
}
|
||||
if args.respStatus == 0 {
|
||||
args.respStatus = http.StatusOK
|
||||
}
|
||||
}
|
||||
|
||||
func (args *testArgs) bodyReader() io.Reader {
|
||||
if args.body != nil {
|
||||
return bytes.NewReader(args.body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
if args == nil {
|
||||
args = new(testArgs)
|
||||
}
|
||||
args.setDefaults()
|
||||
|
||||
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||
for k, v := range args.headers {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
rr := newRequestRecorder(args)
|
||||
if args.realRoundTrip {
|
||||
rr.parent = http.DefaultTransport
|
||||
}
|
||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
|
||||
|
||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
||||
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
|
||||
patchReverseProxy(rp, []*Middleware{mid})
|
||||
rp.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
|
||||
return &TestResult{
|
||||
RequestHeaders: rr.headers,
|
||||
ResponseHeaders: resp.Header,
|
||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -42,7 +43,7 @@ func (tr *Trace) WithRequest(req *Request) *Trace {
|
|||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithResponse(resp *Response) *Trace {
|
||||
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -103,7 +104,7 @@ func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
|||
return m.AddTracef("%s", msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -22,32 +22,61 @@ var (
|
|||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||
)
|
||||
|
||||
const (
|
||||
VarRequestMethod = "$req_method"
|
||||
VarRequestScheme = "$req_scheme"
|
||||
VarRequestHost = "$req_host"
|
||||
VarRequestPort = "$req_port"
|
||||
VarRequestPath = "$req_path"
|
||||
VarRequestAddr = "$req_addr"
|
||||
VarRequestQuery = "$req_query"
|
||||
VarRequestURL = "$req_url"
|
||||
VarRequestURI = "$req_uri"
|
||||
VarRequestContentType = "$req_content_type"
|
||||
VarRequestContentLen = "$req_content_length"
|
||||
VarRemoteAddr = "$remote_addr"
|
||||
VarUpstreamScheme = "$upstream_scheme"
|
||||
VarUpstreamHost = "$upstream_host"
|
||||
VarUpstreamPort = "$upstream_port"
|
||||
VarUpstreamAddr = "$upstream_addr"
|
||||
VarUpstreamURL = "$upstream_url"
|
||||
|
||||
VarRespContentType = "$resp_content_type"
|
||||
VarRespContentLen = "$resp_content_length"
|
||||
VarRespStatusCode = "$status_code"
|
||||
)
|
||||
|
||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
"$req_method": func(req *Request) string { return req.Method },
|
||||
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
|
||||
"$req_host": func(req *Request) string {
|
||||
VarRequestMethod: func(req *Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
},
|
||||
VarRequestHost: func(req *Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return req.Host
|
||||
}
|
||||
return reqHost
|
||||
},
|
||||
"$req_port": func(req *Request) string {
|
||||
VarRequestPort: func(req *Request) string {
|
||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||
return reqPort
|
||||
},
|
||||
"$req_addr": func(req *Request) string { return req.Host },
|
||||
"$req_path": func(req *Request) string { return req.URL.Path },
|
||||
"$req_query": func(req *Request) string { return req.URL.RawQuery },
|
||||
"$req_url": func(req *Request) string { return req.URL.String() },
|
||||
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
|
||||
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
|
||||
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
|
||||
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||
"$upstream_addr": func(req *Request) string {
|
||||
VarRequestAddr: func(req *Request) string { return req.Host },
|
||||
VarRequestPath: func(req *Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
|
||||
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||
VarUpstreamAddr: func(req *Request) string {
|
||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||
if upPort != "" {
|
||||
|
@ -55,7 +84,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
}
|
||||
return upHost
|
||||
},
|
||||
"$upstream_url": func(req *Request) string {
|
||||
VarUpstreamURL: func(req *Request) string {
|
||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return ""
|
||||
|
@ -71,9 +100,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
|||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
"$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
||||
"$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
|
||||
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||
}
|
||||
|
||||
func varReplace(req *Request, resp *Response, s string) string {
|
||||
|
@ -99,7 +128,7 @@ func varReplace(req *Request, resp *Response, s string) string {
|
|||
if resp != nil {
|
||||
// Replace response headers
|
||||
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
|
||||
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
|
||||
return resp.Header.Get(header)
|
||||
})
|
||||
}
|
||||
|
|
8
internal/net/http/proxy_response.go
Normal file
8
internal/net/http/proxy_response.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type ProxyResponse struct {
|
||||
*http.Response
|
||||
OriginalRequest *http.Request
|
||||
}
|
|
@ -87,7 +87,7 @@ type ReverseProxy struct {
|
|||
// If ModifyResponse returns an error, ErrorHandler is called
|
||||
// with its error value. If ErrorHandler is nil, its default
|
||||
// implementation is used.
|
||||
ModifyResponse func(*http.Response) error
|
||||
ModifyResponse func(*ProxyResponse) error
|
||||
|
||||
HandlerFunc http.HandlerFunc
|
||||
|
||||
|
@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() {
|
|||
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
|
||||
}
|
||||
|
||||
func rewriteRequestURL(req *http.Request, target types.URL) {
|
||||
targetQuery := target.RawQuery
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL)
|
||||
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
|
||||
targetQuery := p.TargetURL.RawQuery
|
||||
req.URL.Scheme = p.TargetURL.Scheme
|
||||
req.URL.Host = p.TargetURL.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL)
|
||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||
} else {
|
||||
|
@ -251,11 +251,11 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
|||
|
||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||
// and reports whether the request should proceed.
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
|
||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
|
||||
if p.ModifyResponse == nil {
|
||||
return true
|
||||
}
|
||||
if err := p.ModifyResponse(res); err != nil {
|
||||
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
|
||||
res.Body.Close()
|
||||
p.errorHandler(rw, req, err, true)
|
||||
return false
|
||||
|
@ -349,7 +349,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
rewriteRequestURL(outreq, p.TargetURL)
|
||||
p.rewriteRequestURL(outreq)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := UpgradeType(outreq.Header)
|
||||
|
@ -439,6 +439,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
|
||||
roundTripMutex.Lock()
|
||||
roundTripDone = true
|
||||
roundTripMutex.Unlock()
|
||||
|
@ -459,7 +460,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
if !p.modifyResponse(rw, res, req, outreq) {
|
||||
return
|
||||
}
|
||||
p.handleUpgradeResponse(rw, outreq, res)
|
||||
|
@ -468,7 +469,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
RemoveHopByHopHeaders(res.Header)
|
||||
|
||||
if !p.modifyResponse(rw, res, outreq) {
|
||||
if !p.modifyResponse(rw, res, req, outreq) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue