From eabdd3de00cff13ae8d8d8d6d8976e98278fe839 Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 4 Dec 2024 01:58:17 +0800 Subject: [PATCH] improved middleware variable subsititution --- .../net/http/middleware/modify_request.go | 38 +++-- .../net/http/middleware/modify_response.go | 4 +- internal/net/http/middleware/vars.go | 153 ++++++++++-------- 3 files changed, 109 insertions(+), 86 deletions(-) diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 95037a2..334c18c 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -44,7 +44,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { func (mr *modifyRequest) checkVarSubstitution() { for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} { for _, v := range m { - if strings.Contains(v, "$") { + if strings.ContainsRune(v, '$') { mr.needVarSubstitution = true return } @@ -53,20 +53,32 @@ func (mr *modifyRequest) checkVarSubstitution() { } func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) { - replaceVars := varReplacerDummy - if mr.needVarSubstitution { - replaceVars = varReplacer(req, resp) + if !mr.needVarSubstitution { + for k, v := range mr.SetHeaders { + if req != nil && strings.ToLower(k) == "host" { + defer func() { + req.Host = v + }() + } + headers.Set(k, v) + } + for k, v := range mr.AddHeaders { + headers.Add(k, v) + } + } else { + for k, v := range mr.SetHeaders { + if req != nil && strings.ToLower(k) == "host" { + defer func() { + req.Host = varReplace(req, resp, v) + }() + } + headers.Set(k, varReplace(req, resp, v)) + } + for k, v := range mr.AddHeaders { + headers.Add(k, varReplace(req, resp, v)) + } } - for k, v := range mr.SetHeaders { - if strings.ToLower(k) == "host" { - req.Host = replaceVars(v) - } - headers.Set(k, replaceVars(v)) - } - for k, v := range mr.AddHeaders { - headers.Add(k, replaceVars(v)) - } for _, k := range mr.HideHeaders { headers.Del(k) } diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index ea9b5b6..6ba7fe7 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -1,8 +1,6 @@ package middleware import ( - "net/http" - E "github.com/yusing/go-proxy/internal/error" ) @@ -14,7 +12,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyResponse) mr.m = &Middleware{ impl: mr, - modifyResponse: func(resp *http.Response) error { + modifyResponse: func(resp *Response) error { mr.m.AddTraceResponse("before modify response", resp) mr.modifyHeaders(resp.Request, resp, resp.Header) mr.m.AddTraceResponse("after modify response", resp) diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go index c776903..68f5f67 100644 --- a/internal/net/http/middleware/vars.go +++ b/internal/net/http/middleware/vars.go @@ -10,71 +10,74 @@ import ( gphttp "github.com/yusing/go-proxy/internal/net/http" ) -type varReplaceFunc func(string) string +type ( + reqVarGetter func(*Request) string + respVarGetter func(*Response) string +) var ( reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`) - reHeader = regexp.MustCompile(`\$header\([\w-]+\)`) + reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`) reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`) reStatic = regexp.MustCompile(`\$[\w_]+`) ) -func varSubsMap(req *Request, resp *Response) map[string]func() string { - reqHost, reqPort, err := net.SplitHostPort(req.Host) - if err != nil { - reqHost = req.Host - } - reqAddr := reqHost - if reqPort != "" { - reqAddr += ":" + reqPort - } - - pairs := map[string]func() string{ - "$req_method": func() string { return req.Method }, - "$req_scheme": func() string { return req.URL.Scheme }, - "$req_host": func() string { return reqHost }, - "$req_port": func() string { return reqPort }, - "$req_addr": func() string { return reqAddr }, - "$req_path": func() string { return req.URL.Path }, - "$req_query": func() string { return req.URL.RawQuery }, - "$req_url": func() string { return req.URL.String() }, - "$req_uri": req.URL.RequestURI, - "$req_content_type": func() string { return req.Header.Get("Content-Type") }, - "$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) }, - "$remote_addr": func() string { return req.RemoteAddr }, - } - - if resp != nil { - pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") } - pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") } - pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) } - } - - upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) - if upScheme == "" { - return pairs - } - - upHost := req.Header.Get(gphttp.HeaderUpstreamHost) - upPort := req.Header.Get(gphttp.HeaderUpstreamPort) - upAddr := upHost - if upPort != "" { - upAddr += ":" + upPort - } - upURL := upScheme + "://" + upAddr - - pairs["$upstream_scheme"] = func() string { return upScheme } - pairs["$upstream_host"] = func() string { return upHost } - pairs["$upstream_port"] = func() string { return upPort } - pairs["$upstream_addr"] = func() string { return upAddr } - pairs["$upstream_url"] = func() string { return upURL } - - return pairs +var staticReqVarSubsMap = map[string]reqVarGetter{ + "$req_method": func(req *Request) string { return req.Method }, + "$req_scheme": func(req *Request) string { return req.URL.Scheme }, + "$req_host": func(req *Request) string { + reqHost, _, err := net.SplitHostPort(req.Host) + if err != nil { + return req.Host + } + return reqHost + }, + "$req_port": func(req *Request) string { + _, reqPort, _ := net.SplitHostPort(req.Host) + return reqPort + }, + "$req_addr": func(req *Request) string { return req.Host }, + "$req_path": func(req *Request) string { return req.URL.Path }, + "$req_query": func(req *Request) string { return req.URL.RawQuery }, + "$req_url": func(req *Request) string { return req.URL.String() }, + "$req_uri": func(req *Request) string { return req.URL.RequestURI() }, + "$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") }, + "$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) }, + "$remote_addr": func(req *Request) string { return req.RemoteAddr }, + "$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) }, + "$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) }, + "$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) }, + "$upstream_addr": func(req *Request) string { + upHost := req.Header.Get(gphttp.HeaderUpstreamHost) + upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + if upPort != "" { + return upHost + ":" + upPort + } + return upHost + }, + "$upstream_url": func(req *Request) string { + upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) + if upScheme == "" { + return "" + } + upHost := req.Header.Get(gphttp.HeaderUpstreamHost) + upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upAddr := upHost + if upPort != "" { + upAddr += ":" + upPort + } + return upScheme + "://" + upAddr + }, } -func varReplacer(req *Request, resp *Response) varReplaceFunc { - pairs := varSubsMap(req, resp) - return func(s string) string { +var staticRespVarSubsMap = map[string]respVarGetter{ + "$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") }, + "$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") }, + "$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) }, +} + +func varReplace(req *Request, resp *Response, s string) string { + if req != nil { // Replace query parameters s = reArg.ReplaceAllStringFunc(s, func(match string) string { name := match[5 : len(match)-1] @@ -86,29 +89,39 @@ func varReplacer(req *Request, resp *Response) varReplaceFunc { return "" }) - // Replace headers - s = reHeader.ReplaceAllStringFunc(s, func(match string) string { + // Replace request headers + s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string { header := http.CanonicalHeaderKey(match[8 : len(match)-1]) return req.Header.Get(header) }) + } - if resp != nil { - s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string { - header := http.CanonicalHeaderKey(match[14 : len(match)-1]) - return resp.Header.Get(header) - }) - } + if resp != nil { + // Replace response headers + s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string { + header := http.CanonicalHeaderKey(match[14 : len(match)-1]) + return resp.Header.Get(header) + }) + } - // Replace static variables - return reStatic.ReplaceAllStringFunc(s, func(match string) string { - if fn, ok := pairs[match]; ok { - return fn() + // Replace static variables + if req != nil { + s = reStatic.ReplaceAllStringFunc(s, func(match string) string { + if fn, ok := staticReqVarSubsMap[match]; ok { + return fn(req) + } + return match + }) + } + + if resp != nil { + s = reStatic.ReplaceAllStringFunc(s, func(match string) string { + if fn, ok := staticRespVarSubsMap[match]; ok { + return fn(resp) } return match }) } -} -func varReplacerDummy(s string) string { return s }