mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix modifyResponse middleware incorrect variable substitution
This commit is contained in:
parent
a9f6c4eb20
commit
aff8a3b401
10 changed files with 255 additions and 104 deletions
|
@ -137,13 +137,13 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *Response) error {
|
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
|
||||||
fa.setAuthCookies(resp, authCookies)
|
fa.setAuthCookies(resp, authCookies)
|
||||||
return nil
|
return nil
|
||||||
}), req)
|
}), req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
|
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
|
||||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,14 +16,14 @@ type (
|
||||||
ReverseProxy = gphttp.ReverseProxy
|
ReverseProxy = gphttp.ReverseProxy
|
||||||
ProxyRequest = gphttp.ProxyRequest
|
ProxyRequest = gphttp.ProxyRequest
|
||||||
Request = http.Request
|
Request = http.Request
|
||||||
Response = http.Response
|
Response = gphttp.ProxyResponse
|
||||||
ResponseWriter = http.ResponseWriter
|
ResponseWriter = http.ResponseWriter
|
||||||
Header = http.Header
|
Header = http.Header
|
||||||
Cookie = http.Cookie
|
Cookie = http.Cookie
|
||||||
|
|
||||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||||
RewriteFunc func(req *Request)
|
RewriteFunc func(req *Request)
|
||||||
ModifyResponseFunc = gphttp.ModifyResponseFunc
|
ModifyResponseFunc func(*Response) error
|
||||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
|
||||||
|
|
||||||
OptionsRaw = map[string]any
|
OptionsRaw = map[string]any
|
||||||
|
@ -116,7 +116,9 @@ func (m *Middleware) ModifyResponse(resp *Response) error {
|
||||||
|
|
||||||
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
if m.modifyResponse != nil {
|
if m.modifyResponse != nil {
|
||||||
w = gphttp.NewModifyResponseWriter(w, r, m.modifyResponse)
|
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
|
||||||
|
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if m.before != nil {
|
if m.before != nil {
|
||||||
m.before(next, w, r)
|
m.before(next, w, r)
|
||||||
|
@ -176,7 +178,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
|
|
||||||
if mid.before != nil {
|
if mid.before != nil {
|
||||||
ori := rp.HandlerFunc
|
ori := rp.HandlerFunc
|
||||||
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
|
||||||
mid.before(ori, w, r)
|
mid.before(ori, w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,7 +186,7 @@ func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
if mid.modifyResponse != nil {
|
if mid.modifyResponse != nil {
|
||||||
if rp.ModifyResponse != nil {
|
if rp.ModifyResponse != nil {
|
||||||
ori := rp.ModifyResponse
|
ori := rp.ModifyResponse
|
||||||
rp.ModifyResponse = func(res *http.Response) error {
|
rp.ModifyResponse = func(res *Response) error {
|
||||||
if err := mid.modifyResponse(res); err != nil {
|
if err := mid.modifyResponse(res); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,30 +10,30 @@ import (
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetModifyRequest(t *testing.T) {
|
func TestModifyRequest(t *testing.T) {
|
||||||
opts := OptionsRaw{
|
opts := OptionsRaw{
|
||||||
"set_headers": map[string]string{
|
"set_headers": map[string]string{
|
||||||
"User-Agent": "go-proxy/v0.5.0",
|
"User-Agent": "go-proxy/v0.5.0",
|
||||||
"Host": "$upstream_addr",
|
"Host": VarUpstreamAddr,
|
||||||
"X-Test-Req-Method": "$req_method",
|
"X-Test-Req-Method": VarRequestMethod,
|
||||||
"X-Test-Req-Scheme": "$req_scheme",
|
"X-Test-Req-Scheme": VarRequestScheme,
|
||||||
"X-Test-Req-Host": "$req_host",
|
"X-Test-Req-Host": VarRequestHost,
|
||||||
"X-Test-Req-Port": "$req_port",
|
"X-Test-Req-Port": VarRequestPort,
|
||||||
"X-Test-Req-Addr": "$req_addr",
|
"X-Test-Req-Addr": VarRequestAddr,
|
||||||
"X-Test-Req-Path": "$req_path",
|
"X-Test-Req-Path": VarRequestPath,
|
||||||
"X-Test-Req-Query": "$req_query",
|
"X-Test-Req-Query": VarRequestQuery,
|
||||||
"X-Test-Req-Url": "$req_url",
|
"X-Test-Req-Url": VarRequestURL,
|
||||||
"X-Test-Req-Uri": "$req_uri",
|
"X-Test-Req-Uri": VarRequestURI,
|
||||||
"X-Test-Req-Content-Type": "$req_content_type",
|
"X-Test-Req-Content-Type": VarRequestContentType,
|
||||||
"X-Test-Req-Content-Length": "$req_content_length",
|
"X-Test-Req-Content-Length": VarRequestContentLen,
|
||||||
"X-Test-Remote-Addr": "$remote_addr",
|
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||||
"X-Test-Upstream-Scheme": "$upstream_scheme",
|
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||||
"X-Test-Upstream-Host": "$upstream_host",
|
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||||
"X-Test-Upstream-Port": "$upstream_port",
|
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||||
"X-Test-Upstream-Addr": "$upstream_addr",
|
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||||
"X-Test-Upstream-Url": "$upstream_url",
|
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||||
"X-Test-Content-Type": "$header(Content-Type)",
|
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||||
},
|
},
|
||||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||||
"hide_headers": []string{"Accept"},
|
"hide_headers": []string{"Accept"},
|
||||||
|
@ -84,7 +84,7 @@ func TestSetModifyRequest(t *testing.T) {
|
||||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||||
|
|
||||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
|
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||||
|
|
||||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,10 +14,13 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
mr := new(modifyResponse)
|
mr := new(modifyResponse)
|
||||||
mr.m = &Middleware{
|
mr.m = &Middleware{
|
||||||
impl: mr,
|
impl: mr,
|
||||||
|
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||||
|
next(w, r)
|
||||||
|
},
|
||||||
modifyResponse: func(resp *Response) error {
|
modifyResponse: func(resp *Response) error {
|
||||||
mr.m.AddTraceResponse("before modify response", resp)
|
mr.m.AddTraceResponse("before modify response", resp.Response)
|
||||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
mr.modifyHeaders(resp.OriginalRequest, resp, resp.Header)
|
||||||
mr.m.AddTraceResponse("after modify response", resp)
|
mr.m.AddTraceResponse("after modify response", resp.Response)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,41 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/net/types"
|
||||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetModifyResponse(t *testing.T) {
|
func TestModifyResponse(t *testing.T) {
|
||||||
opts := OptionsRaw{
|
opts := OptionsRaw{
|
||||||
"set_headers": map[string]string{"User-Agent": "go-proxy/v0.5.0"},
|
"set_headers": map[string]string{
|
||||||
|
"X-Test-Resp-Status": VarRespStatusCode,
|
||||||
|
"X-Test-Resp-Content-Type": VarRespContentType,
|
||||||
|
"X-Test-Resp-Content-Length": VarRespContentLen,
|
||||||
|
"X-Test-Resp-Header-Content-Type": "$resp_header(Content-Type)",
|
||||||
|
|
||||||
|
"X-Test-Req-Method": VarRequestMethod,
|
||||||
|
"X-Test-Req-Scheme": VarRequestScheme,
|
||||||
|
"X-Test-Req-Host": VarRequestHost,
|
||||||
|
"X-Test-Req-Port": VarRequestPort,
|
||||||
|
"X-Test-Req-Addr": VarRequestAddr,
|
||||||
|
"X-Test-Req-Path": VarRequestPath,
|
||||||
|
"X-Test-Req-Query": VarRequestQuery,
|
||||||
|
"X-Test-Req-Url": VarRequestURL,
|
||||||
|
"X-Test-Req-Uri": VarRequestURI,
|
||||||
|
"X-Test-Remote-Addr": VarRemoteAddr,
|
||||||
|
"X-Test-Upstream-Scheme": VarUpstreamScheme,
|
||||||
|
"X-Test-Upstream-Host": VarUpstreamHost,
|
||||||
|
"X-Test-Upstream-Port": VarUpstreamPort,
|
||||||
|
"X-Test-Upstream-Addr": VarUpstreamAddr,
|
||||||
|
"X-Test-Upstream-Url": VarUpstreamURL,
|
||||||
|
"X-Test-Header-Content-Type": "$header(Content-Type)",
|
||||||
|
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||||
|
},
|
||||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||||
"hide_headers": []string{"Accept"},
|
"hide_headers": []string{"Accept"},
|
||||||
}
|
}
|
||||||
|
@ -22,14 +48,50 @@ func TestSetModifyResponse(t *testing.T) {
|
||||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
ExpectDeepEqual(t, mr.impl.(*modifyResponse).HideHeaders, opts["hide_headers"].([]string))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("request_headers", func(t *testing.T) {
|
t.Run("response_headers", func(t *testing.T) {
|
||||||
|
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||||
|
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||||
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
result, err := newMiddlewareTest(ModifyResponse, &testArgs{
|
||||||
middlewareOpt: opts,
|
middlewareOpt: opts,
|
||||||
|
reqURL: reqURL,
|
||||||
|
upstreamURL: upstreamURL,
|
||||||
|
body: bytes.Repeat([]byte("a"), 100),
|
||||||
|
headers: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
},
|
||||||
|
respHeaders: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
},
|
||||||
|
respBody: bytes.Repeat([]byte("a"), 50),
|
||||||
|
respStatus: http.StatusOK,
|
||||||
})
|
})
|
||||||
ExpectNoError(t, err)
|
ExpectNoError(t, err)
|
||||||
ExpectEqual(t, result.ResponseHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
|
||||||
t.Log(result.ResponseHeaders.Get("Accept-Encoding"))
|
|
||||||
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
ExpectTrue(t, slices.Contains(result.ResponseHeaders.Values("Accept-Encoding"), "test-value"))
|
||||||
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
ExpectEqual(t, result.ResponseHeaders.Get("Accept"), "")
|
||||||
|
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Status"), "200")
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Type"), "application/json")
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Content-Length"), "50")
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Resp-Header-Content-Type"), "application/json")
|
||||||
|
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Method"), http.MethodGet)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||||
|
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||||
|
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Header-Content-Type"), "application/json")
|
||||||
|
ExpectEqual(t, result.ResponseHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,24 +34,39 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type requestRecorder struct {
|
type requestRecorder struct {
|
||||||
|
args *testArgs
|
||||||
|
|
||||||
parent http.RoundTripper
|
parent http.RoundTripper
|
||||||
headers http.Header
|
headers http.Header
|
||||||
remoteAddr string
|
remoteAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
func newRequestRecorder(args *testArgs) *requestRecorder {
|
||||||
|
return &requestRecorder{args: args}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||||
rt.headers = req.Header
|
rt.headers = req.Header
|
||||||
rt.remoteAddr = req.RemoteAddr
|
rt.remoteAddr = req.RemoteAddr
|
||||||
if rt.parent != nil {
|
if rt.parent != nil {
|
||||||
return rt.parent.RoundTrip(req)
|
resp, err = rt.parent.RoundTrip(req)
|
||||||
|
} else {
|
||||||
|
resp = &http.Response{
|
||||||
|
Status: http.StatusText(rt.args.respStatus),
|
||||||
|
StatusCode: rt.args.respStatus,
|
||||||
|
Header: testHeaders,
|
||||||
|
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
|
||||||
|
ContentLength: int64(len(rt.args.respBody)),
|
||||||
|
Request: req,
|
||||||
|
TLS: req.TLS,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &http.Response{
|
if err == nil {
|
||||||
StatusCode: http.StatusOK,
|
for k, v := range rt.args.respHeaders {
|
||||||
Header: testHeaders,
|
resp.Header[k] = v
|
||||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
}
|
||||||
Request: req,
|
}
|
||||||
TLS: req.TLS,
|
return resp, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TestResult struct {
|
type TestResult struct {
|
||||||
|
@ -64,56 +79,84 @@ type TestResult struct {
|
||||||
|
|
||||||
type testArgs struct {
|
type testArgs struct {
|
||||||
middlewareOpt OptionsRaw
|
middlewareOpt OptionsRaw
|
||||||
reqURL types.URL
|
|
||||||
upstreamURL types.URL
|
upstreamURL types.URL
|
||||||
body []byte
|
|
||||||
realRoundTrip bool
|
realRoundTrip bool
|
||||||
headers http.Header
|
|
||||||
|
reqURL types.URL
|
||||||
|
reqMethod string
|
||||||
|
headers http.Header
|
||||||
|
body []byte
|
||||||
|
|
||||||
|
respHeaders http.Header
|
||||||
|
respBody []byte
|
||||||
|
respStatus int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
func (args *testArgs) setDefaults() {
|
||||||
var body io.Reader
|
|
||||||
var rr requestRecorder
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if args == nil {
|
|
||||||
args = new(testArgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.body != nil {
|
|
||||||
body = bytes.NewReader(args.body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.reqURL.Nil() {
|
if args.reqURL.Nil() {
|
||||||
args.reqURL = E.Must(types.ParseURL("https://example.com"))
|
args.reqURL = E.Must(types.ParseURL("https://example.com"))
|
||||||
}
|
}
|
||||||
|
if args.reqMethod == "" {
|
||||||
req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
|
args.reqMethod = http.MethodGet
|
||||||
for k, v := range args.headers {
|
|
||||||
req.Header[k] = v
|
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
if args.upstreamURL.Nil() {
|
if args.upstreamURL.Nil() {
|
||||||
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||||
}
|
}
|
||||||
|
if args.respHeaders == nil {
|
||||||
|
args.respHeaders = http.Header{}
|
||||||
|
}
|
||||||
|
if args.respBody == nil {
|
||||||
|
args.respBody = []byte("OK")
|
||||||
|
}
|
||||||
|
if args.respStatus == 0 {
|
||||||
|
args.respStatus = http.StatusOK
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (args *testArgs) bodyReader() io.Reader {
|
||||||
|
if args.body != nil {
|
||||||
|
return bytes.NewReader(args.body)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||||
|
if args == nil {
|
||||||
|
args = new(testArgs)
|
||||||
|
}
|
||||||
|
args.setDefaults()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||||
|
for k, v := range args.headers {
|
||||||
|
req.Header[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
rr := newRequestRecorder(args)
|
||||||
if args.realRoundTrip {
|
if args.realRoundTrip {
|
||||||
rr.parent = http.DefaultTransport
|
rr.parent = http.DefaultTransport
|
||||||
}
|
}
|
||||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
|
|
||||||
|
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
||||||
|
|
||||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||||
if setOptErr != nil {
|
if setOptErr != nil {
|
||||||
return nil, setOptErr
|
return nil, setOptErr
|
||||||
}
|
}
|
||||||
|
|
||||||
patchReverseProxy(rp, []*Middleware{mid})
|
patchReverseProxy(rp, []*Middleware{mid})
|
||||||
rp.ServeHTTP(w, req)
|
rp.ServeHTTP(w, req)
|
||||||
|
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
data, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.From(err)
|
return nil, E.From(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &TestResult{
|
return &TestResult{
|
||||||
RequestHeaders: rr.headers,
|
RequestHeaders: rr.headers,
|
||||||
ResponseHeaders: resp.Header,
|
ResponseHeaders: resp.Header,
|
||||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -42,7 +43,7 @@ func (tr *Trace) WithRequest(req *Request) *Trace {
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *Trace) WithResponse(resp *Response) *Trace {
|
func (tr *Trace) WithResponse(resp *http.Response) *Trace {
|
||||||
if tr == nil {
|
if tr == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -103,7 +104,7 @@ func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||||
return m.AddTracef("%s", msg).WithRequest(req)
|
return m.AddTracef("%s", msg).WithRequest(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
func (m *Middleware) AddTraceResponse(msg string, resp *http.Response) *Trace {
|
||||||
if !m.trace {
|
if !m.trace {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,32 +22,61 @@ var (
|
||||||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
VarRequestMethod = "$req_method"
|
||||||
|
VarRequestScheme = "$req_scheme"
|
||||||
|
VarRequestHost = "$req_host"
|
||||||
|
VarRequestPort = "$req_port"
|
||||||
|
VarRequestPath = "$req_path"
|
||||||
|
VarRequestAddr = "$req_addr"
|
||||||
|
VarRequestQuery = "$req_query"
|
||||||
|
VarRequestURL = "$req_url"
|
||||||
|
VarRequestURI = "$req_uri"
|
||||||
|
VarRequestContentType = "$req_content_type"
|
||||||
|
VarRequestContentLen = "$req_content_length"
|
||||||
|
VarRemoteAddr = "$remote_addr"
|
||||||
|
VarUpstreamScheme = "$upstream_scheme"
|
||||||
|
VarUpstreamHost = "$upstream_host"
|
||||||
|
VarUpstreamPort = "$upstream_port"
|
||||||
|
VarUpstreamAddr = "$upstream_addr"
|
||||||
|
VarUpstreamURL = "$upstream_url"
|
||||||
|
|
||||||
|
VarRespContentType = "$resp_content_type"
|
||||||
|
VarRespContentLen = "$resp_content_length"
|
||||||
|
VarRespStatusCode = "$status_code"
|
||||||
|
)
|
||||||
|
|
||||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
"$req_method": func(req *Request) string { return req.Method },
|
VarRequestMethod: func(req *Request) string { return req.Method },
|
||||||
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
|
VarRequestScheme: func(req *Request) string {
|
||||||
"$req_host": func(req *Request) string {
|
if req.TLS != nil {
|
||||||
|
return "https"
|
||||||
|
}
|
||||||
|
return "http"
|
||||||
|
},
|
||||||
|
VarRequestHost: func(req *Request) string {
|
||||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return req.Host
|
return req.Host
|
||||||
}
|
}
|
||||||
return reqHost
|
return reqHost
|
||||||
},
|
},
|
||||||
"$req_port": func(req *Request) string {
|
VarRequestPort: func(req *Request) string {
|
||||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||||
return reqPort
|
return reqPort
|
||||||
},
|
},
|
||||||
"$req_addr": func(req *Request) string { return req.Host },
|
VarRequestAddr: func(req *Request) string { return req.Host },
|
||||||
"$req_path": func(req *Request) string { return req.URL.Path },
|
VarRequestPath: func(req *Request) string { return req.URL.Path },
|
||||||
"$req_query": func(req *Request) string { return req.URL.RawQuery },
|
VarRequestQuery: func(req *Request) string { return req.URL.RawQuery },
|
||||||
"$req_url": func(req *Request) string { return req.URL.String() },
|
VarRequestURL: func(req *Request) string { return req.URL.String() },
|
||||||
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
|
VarRequestURI: func(req *Request) string { return req.URL.RequestURI() },
|
||||||
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
|
VarRequestContentType: func(req *Request) string { return req.Header.Get("Content-Type") },
|
||||||
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
VarRequestContentLen: func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||||
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
|
VarRemoteAddr: func(req *Request) string { return req.RemoteAddr },
|
||||||
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
VarUpstreamScheme: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||||
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
VarUpstreamHost: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||||
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
VarUpstreamPort: func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||||
"$upstream_addr": func(req *Request) string {
|
VarUpstreamAddr: func(req *Request) string {
|
||||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||||
if upPort != "" {
|
if upPort != "" {
|
||||||
|
@ -55,7 +84,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
}
|
}
|
||||||
return upHost
|
return upHost
|
||||||
},
|
},
|
||||||
"$upstream_url": func(req *Request) string {
|
VarUpstreamURL: func(req *Request) string {
|
||||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||||
if upScheme == "" {
|
if upScheme == "" {
|
||||||
return ""
|
return ""
|
||||||
|
@ -71,9 +100,9 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
}
|
}
|
||||||
|
|
||||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||||
"$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
VarRespContentType: func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
||||||
"$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
|
VarRespContentLen: func(resp *Response) string { return strconv.FormatInt(resp.ContentLength, 10) },
|
||||||
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
VarRespStatusCode: func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||||
}
|
}
|
||||||
|
|
||||||
func varReplace(req *Request, resp *Response, s string) string {
|
func varReplace(req *Request, resp *Response, s string) string {
|
||||||
|
@ -99,7 +128,7 @@ func varReplace(req *Request, resp *Response, s string) string {
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
// Replace response headers
|
// Replace response headers
|
||||||
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
|
header := http.CanonicalHeaderKey(match[13 : len(match)-1])
|
||||||
return resp.Header.Get(header)
|
return resp.Header.Get(header)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
8
internal/net/http/proxy_response.go
Normal file
8
internal/net/http/proxy_response.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type ProxyResponse struct {
|
||||||
|
*http.Response
|
||||||
|
OriginalRequest *http.Request
|
||||||
|
}
|
|
@ -87,7 +87,7 @@ type ReverseProxy struct {
|
||||||
// If ModifyResponse returns an error, ErrorHandler is called
|
// If ModifyResponse returns an error, ErrorHandler is called
|
||||||
// with its error value. If ErrorHandler is nil, its default
|
// with its error value. If ErrorHandler is nil, its default
|
||||||
// implementation is used.
|
// implementation is used.
|
||||||
ModifyResponse func(*http.Response) error
|
ModifyResponse func(*ProxyResponse) error
|
||||||
|
|
||||||
HandlerFunc http.HandlerFunc
|
HandlerFunc http.HandlerFunc
|
||||||
|
|
||||||
|
@ -199,11 +199,11 @@ func (p *ReverseProxy) UnregisterMetrics() {
|
||||||
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
|
metrics.GetRouteMetrics().UnregisterService(p.TargetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func rewriteRequestURL(req *http.Request, target types.URL) {
|
func (p *ReverseProxy) rewriteRequestURL(req *http.Request) {
|
||||||
targetQuery := target.RawQuery
|
targetQuery := p.TargetURL.RawQuery
|
||||||
req.URL.Scheme = target.Scheme
|
req.URL.Scheme = p.TargetURL.Scheme
|
||||||
req.URL.Host = target.Host
|
req.URL.Host = p.TargetURL.Host
|
||||||
req.URL.Path, req.URL.RawPath = joinURLPath(target.URL, req.URL)
|
req.URL.Path, req.URL.RawPath = joinURLPath(p.TargetURL.URL, req.URL)
|
||||||
if targetQuery == "" || req.URL.RawQuery == "" {
|
if targetQuery == "" || req.URL.RawQuery == "" {
|
||||||
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
||||||
} else {
|
} else {
|
||||||
|
@ -251,11 +251,11 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err
|
||||||
|
|
||||||
// modifyResponse conditionally runs the optional ModifyResponse hook
|
// modifyResponse conditionally runs the optional ModifyResponse hook
|
||||||
// and reports whether the request should proceed.
|
// and reports whether the request should proceed.
|
||||||
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
|
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, oriReq, req *http.Request) bool {
|
||||||
if p.ModifyResponse == nil {
|
if p.ModifyResponse == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err := p.ModifyResponse(res); err != nil {
|
if err := p.ModifyResponse(&ProxyResponse{Response: res, OriginalRequest: oriReq}); err != nil {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
p.errorHandler(rw, req, err, true)
|
p.errorHandler(rw, req, err, true)
|
||||||
return false
|
return false
|
||||||
|
@ -349,7 +349,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriteRequestURL(outreq, p.TargetURL)
|
p.rewriteRequestURL(outreq)
|
||||||
outreq.Close = false
|
outreq.Close = false
|
||||||
|
|
||||||
reqUpType := UpgradeType(outreq.Header)
|
reqUpType := UpgradeType(outreq.Header)
|
||||||
|
@ -439,6 +439,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||||
|
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
|
|
||||||
roundTripMutex.Lock()
|
roundTripMutex.Lock()
|
||||||
roundTripDone = true
|
roundTripDone = true
|
||||||
roundTripMutex.Unlock()
|
roundTripMutex.Unlock()
|
||||||
|
@ -459,7 +460,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||||
if !p.modifyResponse(rw, res, outreq) {
|
if !p.modifyResponse(rw, res, req, outreq) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.handleUpgradeResponse(rw, outreq, res)
|
p.handleUpgradeResponse(rw, outreq, res)
|
||||||
|
@ -468,7 +469,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
RemoveHopByHopHeaders(res.Header)
|
RemoveHopByHopHeaders(res.Header)
|
||||||
|
|
||||||
if !p.modifyResponse(rw, res, outreq) {
|
if !p.modifyResponse(rw, res, req, outreq) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue