package middleware import ( "net" "net/http" "regexp" "strconv" "strings" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" ) type ( reqVarGetter func(*http.Request) string respVarGetter func(*http.Response) string ) var ( reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`) reReqHeader = regexp.MustCompile(`\$header\([\w-]+\)`) reRespHeader = regexp.MustCompile(`\$resp_header\([\w-]+\)`) reStatic = regexp.MustCompile(`\$[\w_]+`) ) const ( VarRequestMethod = "$req_method" VarRequestScheme = "$req_scheme" VarRequestHost = "$req_host" VarRequestPort = "$req_port" VarRequestPath = "$req_path" VarRequestAddr = "$req_addr" VarRequestQuery = "$req_query" VarRequestURL = "$req_url" VarRequestURI = "$req_uri" VarRequestContentType = "$req_content_type" VarRequestContentLen = "$req_content_length" VarRemoteHost = "$remote_host" VarRemotePort = "$remote_port" VarRemoteAddr = "$remote_addr" VarUpstreamName = "$upstream_name" VarUpstreamScheme = "$upstream_scheme" VarUpstreamHost = "$upstream_host" VarUpstreamPort = "$upstream_port" VarUpstreamAddr = "$upstream_addr" VarUpstreamURL = "$upstream_url" VarRespContentType = "$resp_content_type" VarRespContentLen = "$resp_content_length" VarRespStatusCode = "$status_code" ) var staticReqVarSubsMap = map[string]reqVarGetter{ VarRequestMethod: func(req *http.Request) string { return req.Method }, VarRequestScheme: func(req *http.Request) string { if req.TLS != nil { return "https" } return "http" }, VarRequestHost: func(req *http.Request) string { reqHost, _, err := net.SplitHostPort(req.Host) if err != nil { return req.Host } return reqHost }, VarRequestPort: func(req *http.Request) string { _, reqPort, _ := net.SplitHostPort(req.Host) return reqPort }, VarRequestAddr: func(req *http.Request) string { return req.Host }, VarRequestPath: func(req *http.Request) string { return req.URL.Path }, VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery }, VarRequestURL: func(req *http.Request) string { return req.URL.String() }, VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() }, VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") }, VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) }, VarRemoteHost: func(req *http.Request) string { clientIP, _, err := net.SplitHostPort(req.RemoteAddr) if err == nil { return clientIP } return "" }, VarRemotePort: func(req *http.Request) string { _, clientPort, err := net.SplitHostPort(req.RemoteAddr) if err == nil { return clientPort } return "" }, VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, VarUpstreamName: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamName) }, VarUpstreamScheme: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamScheme) }, VarUpstreamHost: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamHost) }, VarUpstreamPort: func(req *http.Request) string { return req.Header.Get(httpheaders.HeaderUpstreamPort) }, VarUpstreamAddr: func(req *http.Request) string { upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) if upPort != "" { return upHost + ":" + upPort } return upHost }, VarUpstreamURL: func(req *http.Request) string { upScheme := req.Header.Get(httpheaders.HeaderUpstreamScheme) if upScheme == "" { return "" } upHost := req.Header.Get(httpheaders.HeaderUpstreamHost) upPort := req.Header.Get(httpheaders.HeaderUpstreamPort) upAddr := upHost if upPort != "" { upAddr += ":" + upPort } return upScheme + "://" + upAddr }, } var staticRespVarSubsMap = map[string]respVarGetter{ VarRespContentType: func(resp *http.Response) string { return resp.Header.Get("Content-Type") }, VarRespContentLen: func(resp *http.Response) string { return strconv.FormatInt(resp.ContentLength, 10) }, VarRespStatusCode: func(resp *http.Response) string { return strconv.Itoa(resp.StatusCode) }, } func varReplace(req *http.Request, resp *http.Response, s string) string { if req != nil { // Replace query parameters s = reArg.ReplaceAllStringFunc(s, func(match string) string { name := match[5 : len(match)-1] for k, v := range req.URL.Query() { if strings.EqualFold(k, name) { return v[0] } } return "" }) // Replace request headers s = reReqHeader.ReplaceAllStringFunc(s, func(match string) string { header := http.CanonicalHeaderKey(match[8 : len(match)-1]) return req.Header.Get(header) }) } if resp != nil { // Replace response headers s = reRespHeader.ReplaceAllStringFunc(s, func(match string) string { header := http.CanonicalHeaderKey(match[13 : len(match)-1]) return resp.Header.Get(header) }) } // Replace static variables if req != nil { s = reStatic.ReplaceAllStringFunc(s, func(match string) string { if fn, ok := staticReqVarSubsMap[match]; ok { return fn(req) } return match }) } if resp != nil { s = reStatic.ReplaceAllStringFunc(s, func(match string) string { if fn, ok := staticRespVarSubsMap[match]; ok { return fn(resp) } return match }) } return s }