GoDoxy/internal/net/gphttp/middleware/vars.go
2025-05-11 06:33:22 +08:00

156 lines
4.6 KiB
Go

package middleware
import (
"net"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/yusing/go-proxy/internal/route/routes"
)
// reqVarGetter resolves one request-scoped template variable to its string
// value for the given request.
type reqVarGetter func(*http.Request) string

// respVarGetter resolves one response-scoped template variable to its string
// value for the given response.
type respVarGetter func(*http.Response) string
// Patterns for the variable forms recognised by varReplace. Argument and
// header names are made of word characters (letters, digits, '_') and '-'.
// '_' is already part of \w, so it is not listed again in the classes below.
var (
	// reArg matches $arg(name), a query-parameter lookup.
	reArg = regexp.MustCompile(`\$arg\([\w-]+\)`)

	// reReqHeader matches $header(Name), a request-header lookup.
	reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`)

	// reRespHeader matches $resp_header(Name), a response-header lookup.
	reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`)

	// reStatic matches a bare $name variable such as $req_method.
	reStatic = regexp.MustCompile(`\$\w+`)
)
// Names of the static variables that varReplace substitutes. Each name is the
// literal token (including the leading '$') that appears in a template.
const (
	// Request variables, resolved from the incoming *http.Request.
	VarRequestMethod      = "$req_method"
	VarRequestScheme      = "$req_scheme"
	VarRequestHost        = "$req_host"
	VarRequestPort        = "$req_port"
	VarRequestPath        = "$req_path"
	VarRequestAddr        = "$req_addr"
	VarRequestQuery       = "$req_query"
	VarRequestURL         = "$req_url"
	VarRequestURI         = "$req_uri"
	VarRequestContentType = "$req_content_type"
	VarRequestContentLen  = "$req_content_length"

	// Remote (client) address variables, resolved from the request.
	VarRemoteHost = "$remote_host"
	VarRemotePort = "$remote_port"
	VarRemoteAddr = "$remote_addr"

	// Upstream variables, resolved from the request by the routes package.
	VarUpstreamName   = "$upstream_name"
	VarUpstreamScheme = "$upstream_scheme"
	VarUpstreamHost   = "$upstream_host"
	VarUpstreamPort   = "$upstream_port"
	VarUpstreamAddr   = "$upstream_addr"
	VarUpstreamURL    = "$upstream_url"

	// Response variables, resolved from the *http.Response.
	VarRespContentType = "$resp_content_type"
	VarRespContentLen  = "$resp_content_length"
	VarRespStatusCode  = "$status_code"
)
// staticReqVarSubsMap resolves the request-scoped static variables
// ($req_*, $remote_* and $upstream_*) against an *http.Request.
var staticReqVarSubsMap = map[string]reqVarGetter{
	VarRequestMethod: func(r *http.Request) string {
		return r.Method
	},
	VarRequestScheme: func(r *http.Request) string {
		// Only the connection's TLS state decides the scheme.
		scheme := "http"
		if r.TLS != nil {
			scheme = "https"
		}
		return scheme
	},
	VarRequestHost: func(r *http.Request) string {
		// Drop the port when Host has one; otherwise return Host verbatim.
		if host, _, err := net.SplitHostPort(r.Host); err == nil {
			return host
		}
		return r.Host
	},
	VarRequestPort: func(r *http.Request) string {
		// The split error is ignored: the port is empty when Host cannot be split.
		_, port, _ := net.SplitHostPort(r.Host)
		return port
	},
	VarRequestAddr: func(r *http.Request) string {
		return r.Host
	},
	VarRequestPath: func(r *http.Request) string {
		return r.URL.Path
	},
	VarRequestQuery: func(r *http.Request) string {
		return r.URL.RawQuery
	},
	VarRequestURL: func(r *http.Request) string {
		return r.URL.String()
	},
	VarRequestURI: func(r *http.Request) string {
		return r.URL.RequestURI()
	},
	VarRequestContentType: func(r *http.Request) string {
		return r.Header.Get("Content-Type")
	},
	VarRequestContentLen: func(r *http.Request) string {
		return strconv.FormatInt(r.ContentLength, 10)
	},
	VarRemoteHost: func(r *http.Request) string {
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			return ""
		}
		return host
	},
	VarRemotePort: func(r *http.Request) string {
		_, port, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			return ""
		}
		return port
	},
	VarRemoteAddr: func(r *http.Request) string {
		return r.RemoteAddr
	},

	// Upstream values come from the routes package helpers.
	VarUpstreamName:   routes.TryGetUpstreamName,
	VarUpstreamScheme: routes.TryGetUpstreamScheme,
	VarUpstreamHost:   routes.TryGetUpstreamHost,
	VarUpstreamPort:   routes.TryGetUpstreamPort,
	VarUpstreamAddr:   routes.TryGetUpstreamAddr,
	VarUpstreamURL:    routes.TryGetUpstreamURL,
}
// staticRespVarSubsMap resolves the response-scoped static variables
// ($status_code and $resp_*) against an *http.Response.
var staticRespVarSubsMap = map[string]respVarGetter{
	VarRespStatusCode: func(r *http.Response) string {
		return strconv.Itoa(r.StatusCode)
	},
	VarRespContentType: func(r *http.Response) string {
		return r.Header.Get("Content-Type")
	},
	VarRespContentLen: func(r *http.Response) string {
		return strconv.FormatInt(r.ContentLength, 10)
	},
}
// reAnyVar matches every variable form handled by varReplace in one pattern.
// Go regexps prefer earlier alternatives at the same start position, so the
// call forms ($arg(...), $header(...), $resp_header(...)) take precedence over
// the bare $name form, which would otherwise match only their prefix.
var reAnyVar = regexp.MustCompile(strings.Join([]string{
	reArg.String(),
	reReqHeader.String(),
	reRespHeader.String(),
	reStatic.String(),
}, "|"))

// varReplace expands the template variables in s using req and/or resp.
// Either may be nil. A variable whose source is nil, or that is unknown, is
// left in the output unchanged.
//
// Supported forms:
//   - $arg(name): first value of query parameter name from req.URL (see lookupQueryArg)
//   - $header(Name): request header value from req
//   - $resp_header(Name): response header value from resp
//   - $name: static variables from staticReqVarSubsMap / staticRespVarSubsMap
//
// All substitutions happen in a single pass over s. Text produced by one
// variable is never scanned again. Without this, a client-supplied value such
// as a query parameter "$upstream_url" would be expanded into internal data.
func varReplace(req *http.Request, resp *http.Response, s string) string {
	if req == nil && resp == nil {
		return s
	}
	var query map[string][]string // parsed lazily, at most once per call
	return reAnyVar.ReplaceAllStringFunc(s, func(match string) string {
		if fn, rest, isCall := strings.Cut(match, "("); isCall {
			param := strings.TrimSuffix(rest, ")")
			switch fn {
			case "$arg":
				if req != nil {
					if query == nil {
						query = req.URL.Query()
					}
					return lookupQueryArg(query, param)
				}
			case "$header":
				if req != nil {
					return req.Header.Get(param) // Get canonicalizes the key
				}
			case "$resp_header":
				if resp != nil {
					return resp.Header.Get(param)
				}
			}
			return match
		}
		if req != nil {
			if get, ok := staticReqVarSubsMap[match]; ok {
				return get(req)
			}
		}
		if resp != nil {
			if get, ok := staticRespVarSubsMap[match]; ok {
				return get(resp)
			}
		}
		return match
	})
}

// lookupQueryArg returns the first value of the query parameter called name,
// or "" if there is none. An exact key match wins. Otherwise keys are compared
// case-insensitively. If several keys match, the lexicographically smallest is
// used, so the result does not depend on map iteration order.
func lookupQueryArg(query map[string][]string, name string) string {
	if vs := query[name]; len(vs) > 0 {
		return vs[0]
	}
	var (
		key   string
		found bool
	)
	for k, vs := range query {
		if len(vs) > 0 && strings.EqualFold(k, name) && (!found || k < key) {
			key, found = k, true
		}
	}
	if !found {
		return ""
	}
	return query[key][0]
}