mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
support variables in modify request,response middlewares
This commit is contained in:
parent
cd749ac6a4
commit
ef1863f810
15 changed files with 291 additions and 114 deletions
|
@ -12,6 +12,10 @@ const (
|
|||
HeaderXForwardedPort = "X-Forwarded-Port"
|
||||
HeaderXForwardedURI = "X-Forwarded-Uri"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
|
||||
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
|
||||
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
|
||||
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
|
||||
)
|
||||
|
||||
func RemoveHop(h http.Header) {
|
||||
|
|
|
@ -164,7 +164,15 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
|||
}
|
||||
|
||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||
mid := BuildMiddlewareFromChain(rp.TargetName, middlewares)
|
||||
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
|
||||
name: "set_upstream_headers",
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
|
||||
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
|
||||
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
|
||||
next(w, r)
|
||||
},
|
||||
}}, middlewares...))
|
||||
|
||||
if mid.before != nil {
|
||||
ori := rp.HandlerFunc
|
||||
|
|
|
@ -19,12 +19,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
|
|||
eb.Add(err)
|
||||
return nil
|
||||
}
|
||||
mids := BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
results := make(map[string]*Middleware, len(mids))
|
||||
for k, v := range mids {
|
||||
results[k+"@file"] = v
|
||||
}
|
||||
return results
|
||||
return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
|
||||
|
@ -40,7 +35,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
|
|||
if err != nil {
|
||||
eb.Add(err.Subject(source))
|
||||
} else {
|
||||
middlewares[name] = chain
|
||||
middlewares[name+"@file"] = chain
|
||||
}
|
||||
}
|
||||
return middlewares
|
||||
|
|
|
@ -38,17 +38,23 @@ func init() {
|
|||
// snakes and cases will be stripped on `Get`
|
||||
// so keys are lowercase without snake.
|
||||
allMiddlewares = map[string]*Middleware{
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
"modifyresponse": ModifyResponse,
|
||||
"modifyrequest": ModifyRequest,
|
||||
"errorpage": CustomErrorPage,
|
||||
"customerrorpage": CustomErrorPage,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
|
||||
"request": ModifyRequest,
|
||||
"modifyrequest": ModifyRequest,
|
||||
"response": ModifyResponse,
|
||||
"modifyresponse": ModifyResponse,
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
|
||||
"errorpage": CustomErrorPage,
|
||||
"customerrorpage": CustomErrorPage,
|
||||
|
||||
"realip": RealIP,
|
||||
"cloudflarerealip": CloudflareRealIP,
|
||||
"cidrwhitelist": CIDRWhiteList,
|
||||
"ratelimit": RateLimiter,
|
||||
|
||||
"cidrwhitelist": CIDRWhiteList,
|
||||
"ratelimit": RateLimiter,
|
||||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth,
|
||||
|
|
|
@ -2,15 +2,16 @@ package middleware
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyRequest struct {
|
||||
modifyRequestOpts
|
||||
m *Middleware
|
||||
m *Middleware
|
||||
needVarSubstitution bool
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyRequestOpts struct {
|
||||
|
@ -24,38 +25,49 @@ var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
|
|||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyRequest)
|
||||
mrFunc := mr.modifyRequest
|
||||
if common.IsDebug {
|
||||
mrFunc = mr.modifyRequestWithTrace
|
||||
}
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
before: Rewrite(mrFunc),
|
||||
impl: mr,
|
||||
before: Rewrite(func(req *Request) {
|
||||
mr.m.AddTraceRequest("before modify request", req)
|
||||
mr.modifyHeaders(req, nil, req.Header)
|
||||
mr.m.AddTraceRequest("after modify request", req)
|
||||
}),
|
||||
}
|
||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mr.checkVarSubstitution()
|
||||
return mr.m, nil
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyRequest(req *Request) {
|
||||
for k, v := range mr.SetHeaders {
|
||||
if http.CanonicalHeaderKey(k) == "Host" {
|
||||
req.Host = v
|
||||
func (mr *modifyRequest) checkVarSubstitution() {
|
||||
for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
|
||||
for _, v := range m {
|
||||
if strings.Contains(v, "$") {
|
||||
mr.needVarSubstitution = true
|
||||
return
|
||||
}
|
||||
}
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
for _, k := range mr.HideHeaders {
|
||||
req.Header.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
|
||||
mr.m.AddTraceRequest("before modify request", req)
|
||||
mr.modifyRequest(req)
|
||||
mr.m.AddTraceRequest("after modify request", req)
|
||||
func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
|
||||
replaceVars := varReplacerDummy
|
||||
if mr.needVarSubstitution {
|
||||
replaceVars = varReplacer(req, resp)
|
||||
}
|
||||
|
||||
for k, v := range mr.SetHeaders {
|
||||
if strings.ToLower(k) == "host" {
|
||||
req.Host = replaceVars(v)
|
||||
}
|
||||
headers.Set(k, replaceVars(v))
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
headers.Add(k, replaceVars(v))
|
||||
}
|
||||
for _, k := range mr.HideHeaders {
|
||||
headers.Del(k)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,39 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetModifyRequest(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"set_headers": map[string]string{
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": "test.example.com",
|
||||
"User-Agent": "go-proxy/v0.5.0",
|
||||
"Host": "$upstream_addr",
|
||||
"X-Test-Req-Method": "$req_method",
|
||||
"X-Test-Req-Scheme": "$req_scheme",
|
||||
"X-Test-Req-Host": "$req_host",
|
||||
"X-Test-Req-Port": "$req_port",
|
||||
"X-Test-Req-Addr": "$req_addr",
|
||||
"X-Test-Req-Path": "$req_path",
|
||||
"X-Test-Req-Query": "$req_query",
|
||||
"X-Test-Req-Url": "$req_url",
|
||||
"X-Test-Req-Uri": "$req_uri",
|
||||
"X-Test-Req-Content-Type": "$req_content_type",
|
||||
"X-Test-Req-Content-Length": "$req_content_length",
|
||||
"X-Test-Remote-Addr": "$remote_addr",
|
||||
"X-Test-Upstream-Scheme": "$upstream_scheme",
|
||||
"X-Test-Upstream-Host": "$upstream_host",
|
||||
"X-Test-Upstream-Port": "$upstream_port",
|
||||
"X-Test-Upstream-Addr": "$upstream_addr",
|
||||
"X-Test-Upstream-Url": "$upstream_url",
|
||||
"X-Test-Content-Type": "$header(Content-Type)",
|
||||
"X-Test-Arg-Arg_1": "$arg(arg_1)",
|
||||
},
|
||||
"add_headers": map[string]string{"Accept-Encoding": "test-value"},
|
||||
"hide_headers": []string{"Accept"},
|
||||
|
@ -26,13 +48,44 @@ func TestSetModifyRequest(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("request_headers", func(t *testing.T) {
|
||||
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
|
||||
upstreamURL := types.MustParseURL("http://test.example.com")
|
||||
result, err := newMiddlewareTest(ModifyRequest, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
reqURL: reqURL,
|
||||
upstreamURL: upstreamURL,
|
||||
body: bytes.Repeat([]byte("a"), 100),
|
||||
headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
|
||||
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
|
||||
|
||||
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,52 +3,28 @@ package middleware
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type (
|
||||
modifyResponse struct {
|
||||
modifyResponseOpts
|
||||
m *Middleware
|
||||
}
|
||||
// order: set_headers -> add_headers -> hide_headers
|
||||
modifyResponseOpts = modifyRequestOpts
|
||||
)
|
||||
type modifyResponse = modifyRequest
|
||||
|
||||
var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
|
||||
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{impl: mr}
|
||||
if common.IsDebug {
|
||||
mr.m.modifyResponse = mr.modifyResponseWithTrace
|
||||
} else {
|
||||
mr.m.modifyResponse = mr.modifyResponse
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
modifyResponse: func(resp *http.Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp)
|
||||
mr.modifyHeaders(resp.Request, resp, resp.Header)
|
||||
mr.m.AddTraceResponse("after modify response", resp)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
err := Deserialize(optsRaw, &mr.modifyResponseOpts)
|
||||
err := Deserialize(optsRaw, &mr.modifyRequestOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mr.checkVarSubstitution()
|
||||
return mr.m, nil
|
||||
}
|
||||
|
||||
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
for k, v := range mr.SetHeaders {
|
||||
resp.Header.Set(k, v)
|
||||
}
|
||||
for k, v := range mr.AddHeaders {
|
||||
resp.Header.Add(k, v)
|
||||
}
|
||||
for _, k := range mr.HideHeaders {
|
||||
resp.Header.Del(k)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp)
|
||||
err := mr.modifyResponse(resp)
|
||||
mr.m.AddTraceResponse("after modify response", resp)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -5,21 +5,22 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRedirectToHTTPs(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "http",
|
||||
reqURL: types.MustParseURL("http://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort)
|
||||
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
|
||||
}
|
||||
|
||||
func TestNoRedirect(t *testing.T) {
|
||||
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
|
||||
scheme: "https",
|
||||
reqURL: types.MustParseURL("https://example.com"),
|
||||
})
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
|
@ -19,8 +18,6 @@ import (
|
|||
var testHeadersRaw []byte
|
||||
var testHeaders http.Header
|
||||
|
||||
const testHost = "example.com"
|
||||
|
||||
func init() {
|
||||
if !common.IsTest {
|
||||
return
|
||||
|
@ -67,16 +64,16 @@ type TestResult struct {
|
|||
|
||||
type testArgs struct {
|
||||
middlewareOpt OptionsRaw
|
||||
proxyURL string
|
||||
reqURL types.URL
|
||||
upstreamURL types.URL
|
||||
body []byte
|
||||
scheme string
|
||||
realRoundTrip bool
|
||||
headers http.Header
|
||||
}
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||
var body io.Reader
|
||||
var rr requestRecorder
|
||||
var proxyURL *url.URL
|
||||
var requestTarget string
|
||||
var err error
|
||||
|
||||
if args == nil {
|
||||
|
@ -87,34 +84,24 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
|||
body = bytes.NewReader(args.body)
|
||||
}
|
||||
|
||||
switch args.scheme {
|
||||
case "":
|
||||
fallthrough
|
||||
case "http":
|
||||
requestTarget = "http://" + testHost
|
||||
case "https":
|
||||
requestTarget = "https://" + testHost
|
||||
default:
|
||||
panic("typo?")
|
||||
if args.reqURL.Nil() {
|
||||
args.reqURL = E.Must(types.ParseURL("https://example.com"))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, requestTarget, body)
|
||||
req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
|
||||
for k, v := range args.headers {
|
||||
req.Header[k] = v
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if args.scheme == "https" && req.TLS == nil {
|
||||
panic("bug occurred")
|
||||
if args.upstreamURL.Nil() {
|
||||
args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
|
||||
}
|
||||
|
||||
if args.proxyURL != "" {
|
||||
proxyURL, err = url.Parse(args.proxyURL)
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
if args.realRoundTrip {
|
||||
rr.parent = http.DefaultTransport
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gphttp.NewReverseProxy(middleware.name, types.NewURL(proxyURL), &rr)
|
||||
rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
|
|
|
@ -97,10 +97,16 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
|||
}
|
||||
|
||||
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return m.AddTracef("%s", msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return m.AddTracef("%s", msg).WithResponse(resp)
|
||||
}
|
||||
|
||||
|
|
106
internal/net/http/middleware/vars.go
Normal file
106
internal/net/http/middleware/vars.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
|
||||
type varReplaceFunc func(string) string
|
||||
|
||||
var (
|
||||
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
|
||||
reHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
|
||||
reStatic = regexp.MustCompile(`\$[\w_]+`)
|
||||
)
|
||||
|
||||
func varSubsMap(req *Request, resp *Response) map[string]func() string {
|
||||
reqHost, reqPort, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
reqHost = req.Host
|
||||
}
|
||||
reqAddr := reqHost
|
||||
if reqPort != "" {
|
||||
reqAddr += ":" + reqPort
|
||||
}
|
||||
|
||||
pairs := map[string]func() string{
|
||||
"$req_method": func() string { return req.Method },
|
||||
"$req_scheme": func() string { return req.URL.Scheme },
|
||||
"$req_host": func() string { return reqHost },
|
||||
"$req_port": func() string { return reqPort },
|
||||
"$req_addr": func() string { return reqAddr },
|
||||
"$req_path": func() string { return req.URL.Path },
|
||||
"$req_query": func() string { return req.URL.RawQuery },
|
||||
"$req_url": func() string { return req.URL.String() },
|
||||
"$req_uri": req.URL.RequestURI,
|
||||
"$req_content_type": func() string { return req.Header.Get("Content-Type") },
|
||||
"$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
"$remote_addr": func() string { return req.RemoteAddr },
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") }
|
||||
pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") }
|
||||
pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) }
|
||||
}
|
||||
|
||||
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
|
||||
if upScheme == "" {
|
||||
return pairs
|
||||
}
|
||||
|
||||
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
|
||||
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
|
||||
upAddr := upHost
|
||||
if upPort != "" {
|
||||
upAddr += ":" + upPort
|
||||
}
|
||||
upURL := upScheme + "://" + upAddr
|
||||
|
||||
pairs["$upstream_scheme"] = func() string { return upScheme }
|
||||
pairs["$upstream_host"] = func() string { return upHost }
|
||||
pairs["$upstream_port"] = func() string { return upPort }
|
||||
pairs["$upstream_addr"] = func() string { return upAddr }
|
||||
pairs["$upstream_url"] = func() string { return upURL }
|
||||
|
||||
return pairs
|
||||
}
|
||||
|
||||
func varReplacer(req *Request, resp *Response) varReplaceFunc {
|
||||
pairs := varSubsMap(req, resp)
|
||||
return func(s string) string {
|
||||
// Replace query parameters
|
||||
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
|
||||
name := match[5 : len(match)-1]
|
||||
for k, v := range req.URL.Query() {
|
||||
if strings.EqualFold(k, name) {
|
||||
return v[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
// Replace headers
|
||||
s = reHeader.ReplaceAllStringFunc(s, func(match string) string {
|
||||
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
|
||||
return req.Header.Get(header)
|
||||
})
|
||||
|
||||
// Replace static variables
|
||||
return reStatic.ReplaceAllStringFunc(s, func(match string) string {
|
||||
if fn, ok := pairs[match]; ok {
|
||||
return fn()
|
||||
}
|
||||
return match
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func varReplacerDummy(s string) string {
|
||||
return s
|
||||
}
|
|
@ -264,6 +264,9 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
|
|||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
|
||||
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
|
||||
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
|
||||
p.HandlerFunc(rw, req)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,14 @@ type URL struct {
|
|||
*urlPkg.URL
|
||||
}
|
||||
|
||||
func MustParseURL(url string) URL {
|
||||
u, err := ParseURL(url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func ParseURL(url string) (URL, error) {
|
||||
u, err := urlPkg.Parse(url)
|
||||
if err != nil {
|
||||
|
@ -20,6 +28,10 @@ func NewURL(url *urlPkg.URL) URL {
|
|||
return URL{url}
|
||||
}
|
||||
|
||||
func (u URL) Nil() bool {
|
||||
return u.URL == nil
|
||||
}
|
||||
|
||||
func (u URL) String() string {
|
||||
if u.URL == nil {
|
||||
return "nil"
|
||||
|
|
|
@ -101,12 +101,21 @@ func TestWebhookBody(t *testing.T) {
|
|||
})
|
||||
ExpectNoError(t, err)
|
||||
|
||||
var body map[string][]map[string]any
|
||||
var body struct {
|
||||
Embeds []struct {
|
||||
Title string `json:"title"`
|
||||
Fields []struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
} `json:"fields"`
|
||||
} `json:"embeds"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(bodyReader).Decode(&body)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
ExpectEqual(t, body["embeds"][0]["title"], "abc")
|
||||
fields := ExpectType[[]map[string]any](t, body["embeds"][0]["fields"])
|
||||
ExpectEqual(t, fields[0]["name"], "foo")
|
||||
ExpectEqual(t, fields[0]["value"], "bar")
|
||||
ExpectEqual(t, body.Embeds[0].Title, "abc")
|
||||
fields := body.Embeds[0].Fields
|
||||
ExpectEqual(t, fields[0].Name, "foo")
|
||||
ExpectEqual(t, fields[0].Value, "bar")
|
||||
}
|
||||
|
|
|
@ -109,8 +109,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
|
|||
t.Helper()
|
||||
_, ok := got.(T)
|
||||
if !ok {
|
||||
t.Fatalf("expected type %s, got %s", reflect.TypeFor[T](), reflect.TypeOf(got).Elem())
|
||||
return
|
||||
t.Fatalf("expected type %s, got %T", reflect.TypeFor[T](), got)
|
||||
}
|
||||
return got.(T)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue