mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-22 20:24:03 +02:00
improved middleware variable subsititution
This commit is contained in:
parent
fcfb7a0105
commit
eabdd3de00
3 changed files with 109 additions and 86 deletions
|
@ -44,7 +44,7 @@ func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
func (mr *modifyRequest) checkVarSubstitution() {
|
func (mr *modifyRequest) checkVarSubstitution() {
|
||||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||||
for _, v := range m {
|
for _, v := range m {
|
||||||
if strings.Contains(v, "$") {
|
if strings.ContainsRune(v, '$') {
|
||||||
mr.needVarSubstitution = true
|
mr.needVarSubstitution = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -53,20 +53,32 @@ func (mr *modifyRequest) checkVarSubstitution() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
|
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
|
||||||
replaceVars := varReplacerDummy
|
if !mr.needVarSubstitution {
|
||||||
if mr.needVarSubstitution {
|
for k, v := range mr.SetHeaders {
|
||||||
replaceVars = varReplacer(req, resp)
|
if req != nil && strings.ToLower(k) == "host" {
|
||||||
|
defer func() {
|
||||||
|
req.Host = v
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
headers.Set(k, v)
|
||||||
|
}
|
||||||
|
for k, v := range mr.AddHeaders {
|
||||||
|
headers.Add(k, v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for k, v := range mr.SetHeaders {
|
||||||
|
if req != nil && strings.ToLower(k) == "host" {
|
||||||
|
defer func() {
|
||||||
|
req.Host = varReplace(req, resp, v)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
headers.Set(k, varReplace(req, resp, v))
|
||||||
|
}
|
||||||
|
for k, v := range mr.AddHeaders {
|
||||||
|
headers.Add(k, varReplace(req, resp, v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range mr.SetHeaders {
|
|
||||||
if strings.ToLower(k) == "host" {
|
|
||||||
req.Host = replaceVars(v)
|
|
||||||
}
|
|
||||||
headers.Set(k, replaceVars(v))
|
|
||||||
}
|
|
||||||
for k, v := range mr.AddHeaders {
|
|
||||||
headers.Add(k, replaceVars(v))
|
|
||||||
}
|
|
||||||
for _, k := range mr.HideHeaders {
|
for _, k := range mr.HideHeaders {
|
||||||
headers.Del(k)
|
headers.Del(k)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,7 +12,7 @@ func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
mr := new(modifyResponse)
|
mr := new(modifyResponse)
|
||||||
mr.m = &Middleware{
|
mr.m = &Middleware{
|
||||||
impl: mr,
|
impl: mr,
|
||||||
modifyResponse: func(resp *http.Response) error {
|
modifyResponse: func(resp *Response) error {
|
||||||
mr.m.AddTraceResponse("before modify response", resp)
|
mr.m.AddTraceResponse("before modify response", resp)
|
||||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||||
mr.m.AddTraceResponse("after modify response", resp)
|
mr.m.AddTraceResponse("after modify response", resp)
|
||||||
|
|
|
@ -10,71 +10,74 @@ import (
|
||||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type varReplaceFunc func(string) string
|
type (
|
||||||
|
reqVarGetter func(*Request) string
|
||||||
|
respVarGetter func(*Response) string
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
|
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
|
||||||
reHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
|
reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
|
||||||
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
|
reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)
|
||||||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func varSubsMap(req *Request, resp *Response) map[string]func() string {
|
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||||
reqHost, reqPort, err := net.SplitHostPort(req.Host)
|
"$req_method": func(req *Request) string { return req.Method },
|
||||||
if err != nil {
|
"$req_scheme": func(req *Request) string { return req.URL.Scheme },
|
||||||
reqHost = req.Host
|
"$req_host": func(req *Request) string {
|
||||||
}
|
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||||
reqAddr := reqHost
|
if err != nil {
|
||||||
if reqPort != "" {
|
return req.Host
|
||||||
reqAddr += ":" + reqPort
|
}
|
||||||
}
|
return reqHost
|
||||||
|
},
|
||||||
pairs := map[string]func() string{
|
"$req_port": func(req *Request) string {
|
||||||
"$req_method": func() string { return req.Method },
|
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||||
"$req_scheme": func() string { return req.URL.Scheme },
|
return reqPort
|
||||||
"$req_host": func() string { return reqHost },
|
},
|
||||||
"$req_port": func() string { return reqPort },
|
"$req_addr": func(req *Request) string { return req.Host },
|
||||||
"$req_addr": func() string { return reqAddr },
|
"$req_path": func(req *Request) string { return req.URL.Path },
|
||||||
"$req_path": func() string { return req.URL.Path },
|
"$req_query": func(req *Request) string { return req.URL.RawQuery },
|
||||||
"$req_query": func() string { return req.URL.RawQuery },
|
"$req_url": func(req *Request) string { return req.URL.String() },
|
||||||
"$req_url": func() string { return req.URL.String() },
|
"$req_uri": func(req *Request) string { return req.URL.RequestURI() },
|
||||||
"$req_uri": req.URL.RequestURI,
|
"$req_content_type": func(req *Request) string { return req.Header.Get("Content-Type") },
|
||||||
"$req_content_type": func() string { return req.Header.Get("Content-Type") },
|
"$req_content_length": func(req *Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||||
"$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) },
|
"$remote_addr": func(req *Request) string { return req.RemoteAddr },
|
||||||
"$remote_addr": func() string { return req.RemoteAddr },
|
"$upstream_scheme": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamScheme) },
|
||||||
}
|
"$upstream_host": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamHost) },
|
||||||
|
"$upstream_port": func(req *Request) string { return req.Header.Get(gphttp.HeaderUpstreamPort) },
|
||||||
if resp != nil {
|
"$upstream_addr": func(req *Request) string {
|
||||||
pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") }
|
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||||
pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") }
|
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||||
pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) }
|
if upPort != "" {
|
||||||
}
|
return upHost + ":" + upPort
|
||||||
|
}
|
||||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
return upHost
|
||||||
if upScheme == "" {
|
},
|
||||||
return pairs
|
"$upstream_url": func(req *Request) string {
|
||||||
}
|
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||||
|
if upScheme == "" {
|
||||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
return ""
|
||||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
}
|
||||||
upAddr := upHost
|
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||||
if upPort != "" {
|
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||||
upAddr += ":" + upPort
|
upAddr := upHost
|
||||||
}
|
if upPort != "" {
|
||||||
upURL := upScheme + "://" + upAddr
|
upAddr += ":" + upPort
|
||||||
|
}
|
||||||
pairs["$upstream_scheme"] = func() string { return upScheme }
|
return upScheme + "://" + upAddr
|
||||||
pairs["$upstream_host"] = func() string { return upHost }
|
},
|
||||||
pairs["$upstream_port"] = func() string { return upPort }
|
|
||||||
pairs["$upstream_addr"] = func() string { return upAddr }
|
|
||||||
pairs["$upstream_url"] = func() string { return upURL }
|
|
||||||
|
|
||||||
return pairs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func varReplacer(req *Request, resp *Response) varReplaceFunc {
|
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||||
pairs := varSubsMap(req, resp)
|
"$resp_content_type": func(resp *Response) string { return resp.Header.Get("Content-Type") },
|
||||||
return func(s string) string {
|
"$resp_content_length": func(resp *Response) string { return resp.Header.Get("Content-Length") },
|
||||||
|
"$status_code": func(resp *Response) string { return strconv.Itoa(resp.StatusCode) },
|
||||||
|
}
|
||||||
|
|
||||||
|
func varReplace(req *Request, resp *Response, s string) string {
|
||||||
|
if req != nil {
|
||||||
// Replace query parameters
|
// Replace query parameters
|
||||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
name := match[5 : len(match)-1]
|
name := match[5 : len(match)-1]
|
||||||
|
@ -86,29 +89,39 @@ func varReplacer(req *Request, resp *Response) varReplaceFunc {
|
||||||
return ""
|
return ""
|
||||||
})
|
})
|
||||||
|
|
||||||
// Replace headers
|
// Replace request headers
|
||||||
s = reHeader.ReplaceAllStringFunc(s, func(match string) string {
|
s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
|
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
|
||||||
return req.Header.Get(header)
|
return req.Header.Get(header)
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
// Replace response headers
|
||||||
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
|
s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
return resp.Header.Get(header)
|
header := http.CanonicalHeaderKey(match[14 : len(match)-1])
|
||||||
})
|
return resp.Header.Get(header)
|
||||||
}
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Replace static variables
|
// Replace static variables
|
||||||
return reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
if req != nil {
|
||||||
if fn, ok := pairs[match]; ok {
|
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
return fn()
|
if fn, ok := staticReqVarSubsMap[match]; ok {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
return match
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
s = reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||||
|
if fn, ok := staticRespVarSubsMap[match]; ok {
|
||||||
|
return fn(resp)
|
||||||
}
|
}
|
||||||
return match
|
return match
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func varReplacerDummy(s string) string {
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue