From ef1863f810a625abd129c90fd0c6f60e364a1a3b Mon Sep 17 00:00:00 2001 From: yusing Date: Tue, 3 Dec 2024 10:20:18 +0800 Subject: [PATCH] support variables in modify request,response middlewares --- internal/net/http/header_utils.go | 4 + internal/net/http/middleware/middleware.go | 10 +- .../net/http/middleware/middleware_builder.go | 9 +- internal/net/http/middleware/middlewares.go | 24 ++-- .../net/http/middleware/modify_request.go | 58 ++++++---- .../http/middleware/modify_request_test.go | 57 +++++++++- .../net/http/middleware/modify_response.go | 46 ++------ .../net/http/middleware/redirect_http_test.go | 7 +- internal/net/http/middleware/test_utils.go | 41 +++---- internal/net/http/middleware/trace.go | 6 + internal/net/http/middleware/vars.go | 106 ++++++++++++++++++ internal/net/http/reverse_proxy_mod.go | 3 + internal/net/types/url.go | 12 ++ internal/notif/webhook_test.go | 19 +++- internal/utils/testing/testing.go | 3 +- 15 files changed, 291 insertions(+), 114 deletions(-) create mode 100644 internal/net/http/middleware/vars.go diff --git a/internal/net/http/header_utils.go b/internal/net/http/header_utils.go index f78677f..4becb8e 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/http/header_utils.go @@ -12,6 +12,10 @@ const ( HeaderXForwardedPort = "X-Forwarded-Port" HeaderXForwardedURI = "X-Forwarded-Uri" HeaderXRealIP = "X-Real-IP" + + HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme" + HeaderUpstreamHost = "X-GoDoxy-Upstream-Host" + HeaderUpstreamPort = "X-GoDoxy-Upstream-Port" ) func RemoveHop(h http.Header) { diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 72899bb..c6facf8 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -164,7 +164,15 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) ( } func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { - mid := BuildMiddlewareFromChain(rp.TargetName, middlewares) + mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{ + name: "set_upstream_headers", + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { + r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme) + r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname()) + r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port()) + next(w, r) + }, + }}, middlewares...)) if mid.before != nil { ori := rp.HandlerFunc diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index eec7c9e..b26a3fb 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -19,12 +19,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string] eb.Add(err) return nil } - mids := BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) - results := make(map[string]*Middleware, len(mids)) - for k, v := range mids { - results[k+"@file"] = v - } - return results + return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) } func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { @@ -40,7 +35,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str if err != nil { eb.Add(err.Subject(source)) } else { - middlewares[name] = chain + middlewares[name+"@file"] = chain } } return middlewares diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 1a7bff4..c3195e8 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -38,17 +38,23 @@ func init() { // snakes and cases will be stripped on `Get` // so keys are lowercase without snake. allMiddlewares = map[string]*Middleware{ - "setxforwarded": SetXForwarded, - "hidexforwarded": HideXForwarded, - "redirecthttp": RedirectHTTP, - "modifyresponse": ModifyResponse, - "modifyrequest": ModifyRequest, - "errorpage": CustomErrorPage, - "customerrorpage": CustomErrorPage, + "redirecthttp": RedirectHTTP, + + "request": ModifyRequest, + "modifyrequest": ModifyRequest, + "response": ModifyResponse, + "modifyresponse": ModifyResponse, + "setxforwarded": SetXForwarded, + "hidexforwarded": HideXForwarded, + + "errorpage": CustomErrorPage, + "customerrorpage": CustomErrorPage, + "realip": RealIP, "cloudflarerealip": CloudflareRealIP, - "cidrwhitelist": CIDRWhiteList, - "ratelimit": RateLimiter, + + "cidrwhitelist": CIDRWhiteList, + "ratelimit": RateLimiter, // !experimental "forwardauth": ForwardAuth, diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index 89b0622..95037a2 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -2,15 +2,16 @@ package middleware import ( "net/http" + "strings" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" ) type ( modifyRequest struct { modifyRequestOpts - m *Middleware + m *Middleware + needVarSubstitution bool } // order: set_headers -> add_headers -> hide_headers modifyRequestOpts struct { @@ -24,38 +25,49 @@ var ModifyRequest = &Middleware{withOptions: NewModifyRequest} func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyRequest) - mrFunc := mr.modifyRequest - if common.IsDebug { - mrFunc = mr.modifyRequestWithTrace - } mr.m = &Middleware{ - impl: mr, - before: Rewrite(mrFunc), + impl: mr, + before: Rewrite(func(req *Request) { + mr.m.AddTraceRequest("before modify request", req) + mr.modifyHeaders(req, nil, req.Header) + mr.m.AddTraceRequest("after modify request", req) + }), } err := Deserialize(optsRaw, &mr.modifyRequestOpts) if err != nil { return nil, err } + mr.checkVarSubstitution() return mr.m, nil } -func (mr *modifyRequest) modifyRequest(req *Request) { - for k, v := range mr.SetHeaders { - if http.CanonicalHeaderKey(k) == "Host" { - req.Host = v +func (mr *modifyRequest) checkVarSubstitution() { + for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} { + for _, v := range m { + if strings.Contains(v, "$") { + mr.needVarSubstitution = true + return + } } - req.Header.Set(k, v) - } - for k, v := range mr.AddHeaders { - req.Header.Add(k, v) - } - for _, k := range mr.HideHeaders { - req.Header.Del(k) } } -func (mr *modifyRequest) modifyRequestWithTrace(req *Request) { - mr.m.AddTraceRequest("before modify request", req) - mr.modifyRequest(req) - mr.m.AddTraceRequest("after modify request", req) +func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) { + replaceVars := varReplacerDummy + if mr.needVarSubstitution { + replaceVars = varReplacer(req, resp) + } + + for k, v := range mr.SetHeaders { + if strings.ToLower(k) == "host" { + req.Host = replaceVars(v) + } + headers.Set(k, replaceVars(v)) + } + for k, v := range mr.AddHeaders { + headers.Add(k, replaceVars(v)) + } + for _, k := range mr.HideHeaders { + headers.Del(k) + } } diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 3a16c47..83aade1 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -1,17 +1,39 @@ package middleware import ( + "bytes" + "net/http" "slices" "testing" + "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestSetModifyRequest(t *testing.T) { opts := OptionsRaw{ "set_headers": map[string]string{ - "User-Agent": "go-proxy/v0.5.0", - "Host": "test.example.com", + "User-Agent": "go-proxy/v0.5.0", + "Host": "$upstream_addr", + "X-Test-Req-Method": "$req_method", + "X-Test-Req-Scheme": "$req_scheme", + "X-Test-Req-Host": "$req_host", + "X-Test-Req-Port": "$req_port", + "X-Test-Req-Addr": "$req_addr", + "X-Test-Req-Path": "$req_path", + "X-Test-Req-Query": "$req_query", + "X-Test-Req-Url": "$req_url", + "X-Test-Req-Uri": "$req_uri", + "X-Test-Req-Content-Type": "$req_content_type", + "X-Test-Req-Content-Length": "$req_content_length", + "X-Test-Remote-Addr": "$remote_addr", + "X-Test-Upstream-Scheme": "$upstream_scheme", + "X-Test-Upstream-Host": "$upstream_host", + "X-Test-Upstream-Port": "$upstream_port", + "X-Test-Upstream-Addr": "$upstream_addr", + "X-Test-Upstream-Url": "$upstream_url", + "X-Test-Content-Type": "$header(Content-Type)", + "X-Test-Arg-Arg_1": "$arg(arg_1)", }, "add_headers": map[string]string{"Accept-Encoding": "test-value"}, "hide_headers": []string{"Accept"}, @@ -26,13 +48,44 @@ func TestSetModifyRequest(t *testing.T) { }) t.Run("request_headers", func(t *testing.T) { + reqURL := types.MustParseURL("https://my.app/?arg_1=b") + upstreamURL := types.MustParseURL("http://test.example.com") result, err := newMiddlewareTest(ModifyRequest, &testArgs{ middlewareOpt: opts, + reqURL: reqURL, + upstreamURL: upstreamURL, + body: bytes.Repeat([]byte("a"), 100), + headers: http.Header{ + "Content-Type": []string{"application/json"}, + }, }) ExpectNoError(t, err) ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com") ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") + + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET") + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json") + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100") + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr) + + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port()) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host) + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String()) + + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json") + + ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b") }) } diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index d38ab75..ea9b5b6 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -3,52 +3,28 @@ package middleware import ( "net/http" - "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" ) -type ( - modifyResponse struct { - modifyResponseOpts - m *Middleware - } - // order: set_headers -> add_headers -> hide_headers - modifyResponseOpts = modifyRequestOpts -) +type modifyResponse = modifyRequest var ModifyResponse = &Middleware{withOptions: NewModifyResponse} func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { mr := new(modifyResponse) - mr.m = &Middleware{impl: mr} - if common.IsDebug { - mr.m.modifyResponse = mr.modifyResponseWithTrace - } else { - mr.m.modifyResponse = mr.modifyResponse + mr.m = &Middleware{ + impl: mr, + modifyResponse: func(resp *http.Response) error { + mr.m.AddTraceResponse("before modify response", resp) + mr.modifyHeaders(resp.Request, resp, resp.Header) + mr.m.AddTraceResponse("after modify response", resp) + return nil + }, } - err := Deserialize(optsRaw, &mr.modifyResponseOpts) + err := Deserialize(optsRaw, &mr.modifyRequestOpts) if err != nil { return nil, err } + mr.checkVarSubstitution() return mr.m, nil } - -func (mr *modifyResponse) modifyResponse(resp *http.Response) error { - for k, v := range mr.SetHeaders { - resp.Header.Set(k, v) - } - for k, v := range mr.AddHeaders { - resp.Header.Add(k, v) - } - for _, k := range mr.HideHeaders { - resp.Header.Del(k) - } - return nil -} - -func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error { - mr.m.AddTraceResponse("before modify response", resp) - err := mr.modifyResponse(resp) - mr.m.AddTraceResponse("after modify response", resp) - return err -} diff --git a/internal/net/http/middleware/redirect_http_test.go b/internal/net/http/middleware/redirect_http_test.go index b591fc2..82c6581 100644 --- a/internal/net/http/middleware/redirect_http_test.go +++ b/internal/net/http/middleware/redirect_http_test.go @@ -5,21 +5,22 @@ import ( "testing" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/net/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRedirectToHTTPs(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ - scheme: "http", + reqURL: types.MustParseURL("http://example.com"), }) ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) - ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) + ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort) } func TestNoRedirect(t *testing.T) { result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ - scheme: "https", + reqURL: types.MustParseURL("https://example.com"), }) ExpectNoError(t, err) ExpectEqual(t, result.ResponseStatus, http.StatusOK) diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index f1947aa..858fe60 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" @@ -19,8 +18,6 @@ import ( var testHeadersRaw []byte var testHeaders http.Header -const testHost = "example.com" - func init() { if !common.IsTest { return @@ -67,16 +64,16 @@ type TestResult struct { type testArgs struct { middlewareOpt OptionsRaw - proxyURL string + reqURL types.URL + upstreamURL types.URL body []byte - scheme string + realRoundTrip bool + headers http.Header } func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { var body io.Reader var rr requestRecorder - var proxyURL *url.URL - var requestTarget string var err error if args == nil { @@ -87,34 +84,24 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E body = bytes.NewReader(args.body) } - switch args.scheme { - case "": - fallthrough - case "http": - requestTarget = "http://" + testHost - case "https": - requestTarget = "https://" + testHost - default: - panic("typo?") + if args.reqURL.Nil() { + args.reqURL = E.Must(types.ParseURL("https://example.com")) } - req := httptest.NewRequest(http.MethodGet, requestTarget, body) + req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body) + for k, v := range args.headers { + req.Header[k] = v + } w := httptest.NewRecorder() - if args.scheme == "https" && req.TLS == nil { - panic("bug occurred") + if args.upstreamURL.Nil() { + args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect } - if args.proxyURL != "" { - proxyURL, err = url.Parse(args.proxyURL) - if err != nil { - return nil, E.From(err) - } + if args.realRoundTrip { rr.parent = http.DefaultTransport - } else { - proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect } - rp := gphttp.NewReverseProxy(middleware.name, types.NewURL(proxyURL), &rr) + rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr) mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go index b1ad036..1a8e8e8 100644 --- a/internal/net/http/middleware/trace.go +++ b/internal/net/http/middleware/trace.go @@ -97,10 +97,16 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace { } func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace { + if !m.trace { + return nil + } return m.AddTracef("%s", msg).WithRequest(req) } func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace { + if !m.trace { + return nil + } return m.AddTracef("%s", msg).WithResponse(resp) } diff --git a/internal/net/http/middleware/vars.go b/internal/net/http/middleware/vars.go new file mode 100644 index 0000000..74bcbcf --- /dev/null +++ b/internal/net/http/middleware/vars.go @@ -0,0 +1,106 @@ +package middleware + +import ( + "net" + "net/http" + "regexp" + "strconv" + "strings" + + gphttp "github.com/yusing/go-proxy/internal/net/http" +) + +type varReplaceFunc func(string) string + +var ( + reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`) + reHeader = regexp.MustCompile(`\$header\([\w-]+\)`) + reStatic = regexp.MustCompile(`\$[\w_]+`) +) + +func varSubsMap(req *Request, resp *Response) map[string]func() string { + reqHost, reqPort, err := net.SplitHostPort(req.Host) + if err != nil { + reqHost = req.Host + } + reqAddr := reqHost + if reqPort != "" { + reqAddr += ":" + reqPort + } + + pairs := map[string]func() string{ + "$req_method": func() string { return req.Method }, + "$req_scheme": func() string { return req.URL.Scheme }, + "$req_host": func() string { return reqHost }, + "$req_port": func() string { return reqPort }, + "$req_addr": func() string { return reqAddr }, + "$req_path": func() string { return req.URL.Path }, + "$req_query": func() string { return req.URL.RawQuery }, + "$req_url": func() string { return req.URL.String() }, + "$req_uri": req.URL.RequestURI, + "$req_content_type": func() string { return req.Header.Get("Content-Type") }, + "$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) }, + "$remote_addr": func() string { return req.RemoteAddr }, + } + + if resp != nil { + pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") } + pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") } + pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) } + } + + upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme) + if upScheme == "" { + return pairs + } + + upHost := req.Header.Get(gphttp.HeaderUpstreamHost) + upPort := req.Header.Get(gphttp.HeaderUpstreamPort) + upAddr := upHost + if upPort != "" { + upAddr += ":" + upPort + } + upURL := upScheme + "://" + upAddr + + pairs["$upstream_scheme"] = func() string { return upScheme } + pairs["$upstream_host"] = func() string { return upHost } + pairs["$upstream_port"] = func() string { return upPort } + pairs["$upstream_addr"] = func() string { return upAddr } + pairs["$upstream_url"] = func() string { return upURL } + + return pairs +} + +func varReplacer(req *Request, resp *Response) varReplaceFunc { + pairs := varSubsMap(req, resp) + return func(s string) string { + // Replace query parameters + s = reArg.ReplaceAllStringFunc(s, func(match string) string { + name := match[5 : len(match)-1] + for k, v := range req.URL.Query() { + if strings.EqualFold(k, name) { + return v[0] + } + } + return "" + }) + + // Replace headers + s = reHeader.ReplaceAllStringFunc(s, func(match string) string { + header := http.CanonicalHeaderKey(match[8 : len(match)-1]) + return req.Header.Get(header) + }) + + // Replace static variables + return reStatic.ReplaceAllStringFunc(s, func(match string) string { + if fn, ok := pairs[match]; ok { + return fn() + } + return match + }) + } +} + +func varReplacerDummy(s string) string { + return s +} diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index d06f42b..7070525 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -264,6 +264,9 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response } func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme) + // req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname()) + // req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port()) p.HandlerFunc(rw, req) } diff --git a/internal/net/types/url.go b/internal/net/types/url.go index 2d51956..eb984fa 100644 --- a/internal/net/types/url.go +++ b/internal/net/types/url.go @@ -8,6 +8,14 @@ type URL struct { *urlPkg.URL } +func MustParseURL(url string) URL { + u, err := ParseURL(url) + if err != nil { + panic(err) + } + return u +} + func ParseURL(url string) (URL, error) { u, err := urlPkg.Parse(url) if err != nil { @@ -20,6 +28,10 @@ func NewURL(url *urlPkg.URL) URL { return URL{url} } +func (u URL) Nil() bool { + return u.URL == nil +} + func (u URL) String() string { if u.URL == nil { return "nil" diff --git a/internal/notif/webhook_test.go b/internal/notif/webhook_test.go index 85fc588..9fb31ab 100644 --- a/internal/notif/webhook_test.go +++ b/internal/notif/webhook_test.go @@ -101,12 +101,21 @@ func TestWebhookBody(t *testing.T) { }) ExpectNoError(t, err) - var body map[string][]map[string]any + var body struct { + Embeds []struct { + Title string `json:"title"` + Fields []struct { + Name string `json:"name"` + Value string `json:"value"` + } `json:"fields"` + } `json:"embeds"` + } + err = json.NewDecoder(bodyReader).Decode(&body) ExpectNoError(t, err) - ExpectEqual(t, body["embeds"][0]["title"], "abc") - fields := ExpectType[[]map[string]any](t, body["embeds"][0]["fields"]) - ExpectEqual(t, fields[0]["name"], "foo") - ExpectEqual(t, fields[0]["value"], "bar") + ExpectEqual(t, body.Embeds[0].Title, "abc") + fields := body.Embeds[0].Fields + ExpectEqual(t, fields[0].Name, "foo") + ExpectEqual(t, fields[0].Value, "bar") } diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index 8b3217a..e8cecc7 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -109,8 +109,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) { t.Helper() _, ok := got.(T) if !ok { - t.Fatalf("expected type %s, got %s", reflect.TypeFor[T](), reflect.TypeOf(got).Elem()) - return + t.Fatalf("expected type %s, got %T", reflect.TypeFor[T](), got) } return got.(T) }