improved middleware variable subsititution

This commit is contained in:
yusing 2024-12-04 01:58:17 +08:00
parent fcfb7a0105
commit eabdd3de00
3 changed files with 109 additions and 86 deletions

View file

@ -44,7 +44,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
func (mr *modifyRequest) checkVarSubstitution() { func (mr *modifyRequest) checkVarSubstitution() {
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} { for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
for _, v := range m { for _, v := range m {
if strings.Contains(v, "$") { if strings.ContainsRune(v, '$') {
mr.needVarSubstitution = true mr.needVarSubstitution = true
return return
} }
@ -53,20 +53,32 @@ func (mr *modifyRequest) checkVarSubstitution() {
} }
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) { func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
replaceVars := varReplacerDummy if !mr.needVarSubstitution {
if mr.needVarSubstitution {
replaceVars = varReplacer(req, resp)
}
for k, v := range mr.SetHeaders { for k, v := range mr.SetHeaders {
if strings.ToLower(k) == "host" { if req != nil && strings.ToLower(k) == "host" {
req.Host = replaceVars(v) defer func() {
req.Host = v
}()
} }
headers.Set(k, replaceVars(v)) headers.Set(k, v)
} }
for k, v := range mr.AddHeaders { for k, v := range mr.AddHeaders {
headers.Add(k, replaceVars(v)) headers.Add(k, v)
} }
} else {
for k, v := range mr.SetHeaders {
if req != nil && strings.ToLower(k) == "host" {
defer func() {
req.Host = varReplace(req, resp, v)
}()
}
headers.Set(k, varReplace(req, resp, v))
}
for k, v := range mr.AddHeaders {
headers.Add(k, varReplace(req, resp, v))
}
}
for _, k := range mr.HideHeaders { for _, k := range mr.HideHeaders {
headers.Del(k) headers.Del(k)
} }

View file

@ -1,8 +1,6 @@
package middleware package middleware
import ( import (
"net/http"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
@ -14,7 +12,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse) mr := new(modifyResponse)
mr.m = &Middleware{ mr.m = &Middleware{
impl: mr, impl: mr,
modifyResponse: func(resp *http.Response) error { modifyResponse: func(resp *Response) error {
mr.m.AddTraceResponse("before modify response", resp) mr.m.AddTraceResponse("before modify response", resp)
mr.modifyHeaders(resp.Request, resp, resp.Header) mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp) mr.m.AddTraceResponse("after modify response", resp)

View file

@ -10,71 +10,74 @@ import (
gphttp "github.com/yusing/go-proxy/internal/net/http" gphttp "github.com/yusing/go-proxy/internal/net/http"
) )
type varReplaceFunc func(string) string type (
reqVarGetter func(*Request) string
respVarGetter func(*Response) string
)
var ( var (
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`) reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
reHeader = regexp.MustCompile(`\$header\([\w-]+\)`) reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`) reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
reStatic = regexp.MustCompile(`\$[\w_]+`) reStatic = regexp.MustCompile(`\$[\w_]+`)
) )
func varSubsMap(req *Request, resp *Response) map[string]func() string { var staticReqVarSubsMap = map[string]reqVarGetter{
reqHost, reqPort, err := net.SplitHostPort(req.Host) "$req_method": func(req *Request) string { return req.Method },
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
"$req_host": func(req *Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil { if err != nil {
reqHost = req.Host return req.Host
} }
reqAddr := reqHost return reqHost
if reqPort != "" { },
reqAddr += ":" + reqPort "$req_port": func(req *Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
"$req_addr": func(req *Request) string { return req.Host },
"$req_path": func(req *Request) string { return req.URL.Path },
"$req_query": func(req *Request) string { return req.URL.RawQuery },
"$req_url": func(req *Request) string { return req.URL.String() },
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
"$upstream_addr": func(req *Request) string {
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
if upPort != "" {
return upHost + ":" + upPort
} }
return upHost
pairs := map[string]func() string{ },
"$req_method": func() string { return req.Method }, "$upstream_url": func(req *Request) string {
"$req_scheme": func() string { return req.URL.Scheme },
"$req_host": func() string { return reqHost },
"$req_port": func() string { return reqPort },
"$req_addr": func() string { return reqAddr },
"$req_path": func() string { return req.URL.Path },
"$req_query": func() string { return req.URL.RawQuery },
"$req_url": func() string { return req.URL.String() },
"$req_uri": req.URL.RequestURI,
"$req_content_type": func() string { return req.Header.Get("Content-Type") },
"$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func() string { return req.RemoteAddr },
}
if resp != nil {
pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") }
pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") }
pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) }
}
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" { if upScheme == "" {
return pairs return ""
} }
upHost := req.Header.Get(gphttp.HeaderUpstreamHost) upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort) upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upAddr := upHost upAddr := upHost
if upPort != "" { if upPort != "" {
upAddr += ":" + upPort upAddr += ":" + upPort
} }
upURL := upScheme + "://" + upAddr return upScheme + "://" + upAddr
},
pairs["$upstream_scheme"] = func() string { return upScheme }
pairs["$upstream_host"] = func() string { return upHost }
pairs["$upstream_port"] = func() string { return upPort }
pairs["$upstream_addr"] = func() string { return upAddr }
pairs["$upstream_url"] = func() string { return upURL }
return pairs
} }
func varReplacer(req *Request, resp *Response) varReplaceFunc { var staticRespVarSubsMap = map[string]respVarGetter{
pairs := varSubsMap(req, resp) "$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
return func(s string) string { "$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
}
func varReplace(req *Request, resp *Response, s string) string {
if req != nil {
// Replace query parameters // Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string { s = reArg.ReplaceAllStringFunc(s, func(match string) string {
name := match[5 : len(match)-1] name := match[5 : len(match)-1]
@ -86,13 +89,15 @@ func varReplacer(req *Request, resp *Response) varReplaceFunc {
return "" return ""
}) })
// Replace headers // Replace request headers
s = reHeader.ReplaceAllStringFunc(s, func(match string) string { s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[8 : len(match)-1]) header := http.CanonicalHeaderKey(match[8 : len(match)-1])
return req.Header.Get(header) return req.Header.Get(header)
}) })
}
if resp != nil { if resp != nil {
// Replace response headers
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string { s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[14 : len(match)-1]) header := http.CanonicalHeaderKey(match[14 : len(match)-1])
return resp.Header.Get(header) return resp.Header.Get(header)
@ -100,15 +105,23 @@ func varReplacer(req *Request, resp *Response) varReplaceFunc {
} }
// Replace static variables // Replace static variables
return reStatic.ReplaceAllStringFunc(s, func(match string) string { if req != nil {
if fn, ok := pairs[match]; ok { s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
return fn() if fn, ok := staticReqVarSubsMap[match]; ok {
return fn(req)
} }
return match return match
}) })
} }
if resp != nil {
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := staticRespVarSubsMap[match]; ok {
return fn(resp)
}
return match
})
} }
func varReplacerDummy(s string) string {
return s return s
} }